Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +38 -35
- .gitignore +14 -0
- .vscode/settings.json +3 -0
- ARCHITECTURE.md +349 -0
- FIX_MEMORY_ISSUE.md +79 -0
- PRESENTATION_SCRIPT.md +345 -0
- QUICKSTART.md +225 -0
- README.md +81 -12
- UI_GUIDE.md +185 -0
- UI_IMPLEMENTATION.md +282 -0
- UI_README.md +172 -0
- UI_SUMMARY.md +405 -0
- __init__.py +0 -0
- adapters/build_corpora.py +126 -0
- adapters/pubmed.py +132 -0
- app_retrieval_cached.py +376 -0
- classifier/__init__.py +0 -0
- classifier/config.py +13 -0
- classifier/head.py +85 -0
- classifier/infer.py +86 -0
- classifier/modelcard_template.md +200 -0
- classifier/query_router.py +383 -0
- classifier/reason/README.md +311 -0
- classifier/reason/__init__.py +17 -0
- classifier/reason/infer_reason.py +209 -0
- classifier/reason/reason_classifier.py +366 -0
- classifier/reason/train_reason.py +224 -0
- classifier/train.py +324 -0
- classifier/utils.py +75 -0
- cli/healthcare_classifier_cli.py +317 -0
- cli/reason_classifier_cli.py +202 -0
- config.py +35 -0
- environment.yml +174 -0
- launch_ui.bat +28 -0
- launch_ui.ps1 +33 -0
- main.py +78 -0
- pipeline.py +83 -0
- reason_data_analysis.py +81 -0
- requirements-admin.txt +9 -0
- requirements-train.txt +6 -0
- requirements.txt +14 -0
- retriever/__init__.py +1 -0
- retriever/data_schemas.py +8 -0
- retriever/index_bm25.py +14 -0
- retriever/index_dense.py +217 -0
- retriever/ingest.py +22 -0
- retriever/rrf.py +11 -0
- retriever/search.py +43 -0
- retriever/utils.py +15 -0
- scripts/query.py +58 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,38 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
frontend/public/assets/jordan.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
frontend/public/assets/sacha.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
frontend/public/assets/alex.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cache/
|
| 2 |
+
env.list
|
| 3 |
+
__pycache__
|
| 4 |
+
**/__pycache__
|
| 5 |
+
frontend/node_modules
|
| 6 |
+
frontend/build
|
| 7 |
+
.git
|
| 8 |
+
data/
|
| 9 |
+
indexes/
|
| 10 |
+
classifier/checkpoints/
|
| 11 |
+
.DS_Store
|
| 12 |
+
.cache
|
| 13 |
+
.venv/
|
| 14 |
+
.gradio/
|
.vscode/settings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python.analysis.typeCheckingMode": "standard"
|
| 3 |
+
}
|
ARCHITECTURE.md
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Medical Q&A Bot - System Architecture
|
| 2 |
+
|
| 3 |
+
## Visual Overview
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
┌─────────────────────────────────────────────────────────────────┐
|
| 7 |
+
│ USER INTERFACE │
|
| 8 |
+
│ │
|
| 9 |
+
│ ┌──────────────────────┐ ┌──────────────────────┐ │
|
| 10 |
+
│ │ Gradio Web UI │ │ Streamlit Web UI │ │
|
| 11 |
+
│ │ (app.py) │ OR │ (app_streamlit.py) │ │
|
| 12 |
+
│ │ Port: 7860 │ │ Port: 8501 │ │
|
| 13 |
+
│ └──────────┬───────────┘ └──────────┬───────────┘ │
|
| 14 |
+
└─────────────┼────────────────────────────────┼─────────────────┘
|
| 15 |
+
│ │
|
| 16 |
+
└────────────────┬───────────────┘
|
| 17 |
+
│
|
| 18 |
+
▼
|
| 19 |
+
┌────────────────────────────────┐
|
| 20 |
+
│ Query Processing Layer │
|
| 21 |
+
│ │
|
| 22 |
+
│ 1. Text Input Validation │
|
| 23 |
+
│ 2. Embedding Generation │
|
| 24 |
+
│ 3. Model Inference │
|
| 25 |
+
└────────────┬───────────────────┘
|
| 26 |
+
│
|
| 27 |
+
▼
|
| 28 |
+
┌────────────────────────────────┐
|
| 29 |
+
│ CLASSIFIER MODULE │
|
| 30 |
+
│ (classifier/) │
|
| 31 |
+
│ │
|
| 32 |
+
│ ┌──────────────────────────┐ │
|
| 33 |
+
│ │ SentenceTransformer │ │
|
| 34 |
+
│ │ Embedding Model │ │
|
| 35 |
+
│ └───────────┬──────────────┘ │
|
| 36 |
+
│ │ │
|
| 37 |
+
│ ▼ │
|
| 38 |
+
│ ┌──────────────────────────┐ │
|
| 39 |
+
│ │ Classification Head │ │
|
| 40 |
+
│ │ (Neural Network) │ │
|
| 41 |
+
│ └───────────┬──────────────┘ │
|
| 42 |
+
└──────────────┼─────────────────┘
|
| 43 |
+
│
|
| 44 |
+
┌──────────┴──────────┐
|
| 45 |
+
│ │
|
| 46 |
+
┌────────▼────────┐ ┌───────▼────────┐
|
| 47 |
+
│ MEDICAL │ │ ADMINISTRATIVE│
|
| 48 |
+
│ QUERY │ │ QUERY │
|
| 49 |
+
└────────┬────────┘ └───────┬────────┘
|
| 50 |
+
│ │
|
| 51 |
+
│ └──► End (No Retrieval)
|
| 52 |
+
│
|
| 53 |
+
▼
|
| 54 |
+
┌─────────────────────────────────┐
|
| 55 |
+
│ RETRIEVAL MODULE │
|
| 56 |
+
│ (retriever/) │
|
| 57 |
+
│ │
|
| 58 |
+
│ ┌────────────────────────┐ │
|
| 59 |
+
│ │ BM25 Search │ │
|
| 60 |
+
│ │ (Sparse Retrieval) │ │
|
| 61 |
+
│ └───────────┬────────────┘ │
|
| 62 |
+
│ │ │
|
| 63 |
+
│ ┌───────────▼────────────┐ │
|
| 64 |
+
│ │ Dense Search │ │
|
| 65 |
+
│ │ (Vector Similarity) │ │
|
| 66 |
+
│ └───────────┬────────────┘ │
|
| 67 |
+
│ │ │
|
| 68 |
+
│ ┌───────────▼────────────┐ │
|
| 69 |
+
│ │ RRF Fusion │ │
|
| 70 |
+
│ │ (Rank Combination) │ │
|
| 71 |
+
│ └───��───────┬────────────┘ │
|
| 72 |
+
│ │ │
|
| 73 |
+
│ ┌───────────▼────────────┐ │
|
| 74 |
+
│ │ Optional Reranker │ │
|
| 75 |
+
│ │ (Cross-Encoder) │ │
|
| 76 |
+
│ └───────────┬────────────┘ │
|
| 77 |
+
└──────────────┼─────────────────┘
|
| 78 |
+
│
|
| 79 |
+
▼
|
| 80 |
+
┌───────────────────────┐
|
| 81 |
+
│ DATA SOURCES │
|
| 82 |
+
│ │
|
| 83 |
+
│ • PubMed Articles │
|
| 84 |
+
│ • Miriad Q&A │
|
| 85 |
+
│ • UniDoc Q&A │
|
| 86 |
+
│ │
|
| 87 |
+
│ (data/corpora/) │
|
| 88 |
+
└───────────┬───────────┘
|
| 89 |
+
│
|
| 90 |
+
▼
|
| 91 |
+
┌───────────────────────┐
|
| 92 |
+
│ RESULTS │
|
| 93 |
+
│ │
|
| 94 |
+
│ • Document Title │
|
| 95 |
+
│ • Text Content │
|
| 96 |
+
│ • Relevance Scores │
|
| 97 |
+
│ • Metadata │
|
| 98 |
+
└───────────┬───────────┘
|
| 99 |
+
│
|
| 100 |
+
▼
|
| 101 |
+
┌───────────────────────┐
|
| 102 |
+
│ UI DISPLAY │
|
| 103 |
+
│ │
|
| 104 |
+
│ • Formatted Cards │
|
| 105 |
+
│ • JSON View │
|
| 106 |
+
│ • Score Badges │
|
| 107 |
+
└───────────────────────┘
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
## Data Flow
|
| 111 |
+
|
| 112 |
+
### 1. User Input
|
| 113 |
+
```
|
| 114 |
+
User Types Query → Web Interface Captures Input → Sends to Backend
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
### 2. Classification Phase
|
| 118 |
+
```
|
| 119 |
+
Query Text
|
| 120 |
+
↓
|
| 121 |
+
Sentence Transformer (Embedding)
|
| 122 |
+
↓
|
| 123 |
+
Classification Head (Neural Network)
|
| 124 |
+
↓
|
| 125 |
+
Output: [Medical | Administrative | Other] + Confidence Scores
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### 3. Retrieval Phase (Medical Queries Only)
|
| 129 |
+
```
|
| 130 |
+
Medical Query
|
| 131 |
+
↓
|
| 132 |
+
┌────────────────────────┐
|
| 133 |
+
│ Parallel Retrieval │
|
| 134 |
+
│ ┌─────────────────┐ │
|
| 135 |
+
│ │ BM25 (Sparse) │ │ ← Top 100 docs
|
| 136 |
+
│ └─────────────────┘ │
|
| 137 |
+
│ ┌─────────────────┐ │
|
| 138 |
+
│ │ Dense (Vector) │ │ ← Top 100 docs
|
| 139 |
+
│ └─────────────────┘ │
|
| 140 |
+
└────────────────────────┘
|
| 141 |
+
↓
|
| 142 |
+
RRF Fusion Algorithm
|
| 143 |
+
↓
|
| 144 |
+
Top K Candidates
|
| 145 |
+
↓
|
| 146 |
+
Optional: Cross-Encoder Reranking
|
| 147 |
+
↓
|
| 148 |
+
Final Top N Results
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
## Technology Stack
|
| 152 |
+
|
| 153 |
+
### Frontend
|
| 154 |
+
- **Gradio** - Primary UI framework
|
| 155 |
+
- **Streamlit** - Alternative UI framework
|
| 156 |
+
- **HTML/CSS** - Custom styling
|
| 157 |
+
- **JavaScript** - Auto-generated by frameworks
|
| 158 |
+
|
| 159 |
+
### Backend
|
| 160 |
+
- **Python 3.8+** - Core language
|
| 161 |
+
- **PyTorch** - Deep learning framework
|
| 162 |
+
- **Sentence-Transformers** - Embedding models
|
| 163 |
+
- **scikit-learn** - ML utilities
|
| 164 |
+
|
| 165 |
+
### Search & Retrieval
|
| 166 |
+
- **Rank-BM25** - Sparse retrieval
|
| 167 |
+
- **FAISS** - Dense vector search
|
| 168 |
+
- **Custom RRF** - Rank fusion
|
| 169 |
+
- **Cross-Encoder** - Optional reranking
|
| 170 |
+
|
| 171 |
+
### Data
|
| 172 |
+
- **PubMed** - Medical research articles
|
| 173 |
+
- **Miriad** - Medical Q&A database
|
| 174 |
+
- **UniDoc** - Unified document corpus
|
| 175 |
+
- **JSONL** - Data storage format
|
| 176 |
+
|
| 177 |
+
## Component Interactions
|
| 178 |
+
|
| 179 |
+
### 1. Initialization
|
| 180 |
+
```python
|
| 181 |
+
# Load models once at startup
|
| 182 |
+
embedding_model, classifier = classifier_init()
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
### 2. Classification
|
| 186 |
+
```python
|
| 187 |
+
classification = predict_query(
|
| 188 |
+
text=[query],
|
| 189 |
+
embedding_model=embedding_model,
|
| 190 |
+
classifier_head=classifier
|
| 191 |
+
)
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### 3. Retrieval
|
| 195 |
+
```python
|
| 196 |
+
hits = get_candidates(
|
| 197 |
+
query=query,
|
| 198 |
+
k_retrieve=10,
|
| 199 |
+
use_reranker=False
|
| 200 |
+
)
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
### 4. Display
|
| 204 |
+
```python
|
| 205 |
+
# Gradio displays results in tabs
|
| 206 |
+
# - Formatted HTML view
|
| 207 |
+
# - Raw JSON view
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
## Performance Characteristics
|
| 211 |
+
|
| 212 |
+
### Speed
|
| 213 |
+
- **Classification**: ~100-500ms
|
| 214 |
+
- **BM25 Search**: ~50-200ms
|
| 215 |
+
- **Dense Search**: ~100-300ms
|
| 216 |
+
- **Reranking**: ~500-2000ms (if enabled)
|
| 217 |
+
|
| 218 |
+
### Accuracy
|
| 219 |
+
- **Classification**: ~95% accuracy
|
| 220 |
+
- **Retrieval**: Depends on corpus and query
|
| 221 |
+
- **Reranking**: +5-10% improvement
|
| 222 |
+
|
| 223 |
+
### Resource Usage
|
| 224 |
+
- **Memory**: ~2-4 GB (with models loaded)
|
| 225 |
+
- **CPU**: Moderate during inference
|
| 226 |
+
- **GPU**: Optional (speeds up inference)
|
| 227 |
+
|
| 228 |
+
## Scalability Considerations
|
| 229 |
+
|
| 230 |
+
### Current Setup (Single User)
|
| 231 |
+
- ✅ Perfect for demos and development
|
| 232 |
+
- ✅ Low latency
|
| 233 |
+
- ✅ Easy to debug
|
| 234 |
+
|
| 235 |
+
### Future Scaling Options
|
| 236 |
+
- 🔄 Add caching for common queries
|
| 237 |
+
- 🔄 Deploy on cloud with autoscaling
|
| 238 |
+
- 🔄 Use model quantization for faster inference
|
| 239 |
+
- 🔄 Implement request queuing
|
| 240 |
+
- 🔄 Add load balancing
|
| 241 |
+
|
| 242 |
+
## Security & Privacy
|
| 243 |
+
|
| 244 |
+
### Current Implementation
|
| 245 |
+
- Local hosting only
|
| 246 |
+
- No data persistence
|
| 247 |
+
- No user tracking
|
| 248 |
+
- No authentication (optional)
|
| 249 |
+
|
| 250 |
+
### Production Considerations
|
| 251 |
+
- Add user authentication
|
| 252 |
+
- Implement rate limiting
|
| 253 |
+
- Sanitize inputs
|
| 254 |
+
- Log access for auditing
|
| 255 |
+
- HTTPS for encrypted communication
|
| 256 |
+
|
| 257 |
+
## Monitoring & Debugging
|
| 258 |
+
|
| 259 |
+
### Available Information
|
| 260 |
+
- Query classification results
|
| 261 |
+
- Confidence scores per category
|
| 262 |
+
- Retrieval scores (BM25, Dense, RRF)
|
| 263 |
+
- Document metadata
|
| 264 |
+
- Error messages
|
| 265 |
+
|
| 266 |
+
### Debug Mode
|
| 267 |
+
```python
|
| 268 |
+
# In app.py, set:
|
| 269 |
+
demo.launch(show_error=True) # Shows detailed errors
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
## Deployment Options
|
| 273 |
+
|
| 274 |
+
### 1. Local (Current)
|
| 275 |
+
```
|
| 276 |
+
Pros: Easy, fast, secure
|
| 277 |
+
Cons: Single user, not accessible remotely
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
### 2. Hugging Face Spaces
|
| 281 |
+
```
|
| 282 |
+
Pros: Free, easy deploy, public URL
|
| 283 |
+
Cons: Limited resources, public access
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
### 3. Cloud (AWS/GCP/Azure)
|
| 287 |
+
```
|
| 288 |
+
Pros: Scalable, private, customizable
|
| 289 |
+
Cons: Costs money, requires setup
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
### 4. Docker Container
|
| 293 |
+
```
|
| 294 |
+
Pros: Portable, consistent environment
|
| 295 |
+
Cons: Requires Docker knowledge
|
| 296 |
+
```
|
| 297 |
+
|
| 298 |
+
## File Structure
|
| 299 |
+
|
| 300 |
+
```
|
| 301 |
+
health-query-classifier/
|
| 302 |
+
├── 🖥️ UI Layer
|
| 303 |
+
│ ├── app.py # Main Gradio UI
|
| 304 |
+
│ ├── app_streamlit.py # Alternative Streamlit UI
|
| 305 |
+
│ ├── launch_ui.bat # Windows launcher
|
| 306 |
+
│ └── launch_ui.ps1 # PowerShell launcher
|
| 307 |
+
│
|
| 308 |
+
├── 🧠 Classifier Layer
|
| 309 |
+
│ ├── classifier/
|
| 310 |
+
│ │ ├── infer.py # Inference logic
|
| 311 |
+
│ │ ├── head.py # Classification head
|
| 312 |
+
│ │ ├── train.py # Training script
|
| 313 |
+
│ │ └── utils.py # Utilities
|
| 314 |
+
│
|
| 315 |
+
├── 🔍 Retrieval Layer
|
| 316 |
+
│ ├── retriever/
|
| 317 |
+
│ │ ├── search.py # Search interface
|
| 318 |
+
│ │ ├── index_bm25.py # BM25 indexing
|
| 319 |
+
│ │ ├── index_dense.py # Dense indexing
|
| 320 |
+
│ │ └── rrf.py # Rank fusion
|
| 321 |
+
│
|
| 322 |
+
├── 👥 Team Layer
|
| 323 |
+
│ ├── team/
|
| 324 |
+
│ │ ├── candidates.py # Candidate retrieval
|
| 325 |
+
│ │ └── interfaces.py # Data interfaces
|
| 326 |
+
│
|
| 327 |
+
├── 📊 Data Layer
|
| 328 |
+
│ ├── data/
|
| 329 |
+
│ │ └── corpora/ # Corpus files
|
| 330 |
+
│ │ ├── medical_qa.jsonl
|
| 331 |
+
│ │ ├── miriad_text.jsonl
|
| 332 |
+
│ │ └── unidoc_qa.jsonl
|
| 333 |
+
│
|
| 334 |
+
└── 📚 Documentation
|
| 335 |
+
├── README.md # Main documentation
|
| 336 |
+
├── QUICKSTART.md # Quick start guide
|
| 337 |
+
├── UI_README.md # UI documentation
|
| 338 |
+
├── UI_IMPLEMENTATION.md # Implementation details
|
| 339 |
+
└── ARCHITECTURE.md # This file
|
| 340 |
+
```
|
| 341 |
+
|
| 342 |
+
---
|
| 343 |
+
|
| 344 |
+
This architecture ensures:
|
| 345 |
+
- ✅ Clean separation of concerns
|
| 346 |
+
- ✅ Modular design
|
| 347 |
+
- ✅ Easy to test and debug
|
| 348 |
+
- ✅ Scalable and maintainable
|
| 349 |
+
- ✅ Well-documented
|
FIX_MEMORY_ISSUE.md
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fixing Memory Issue - Windows Virtual Memory
|
| 2 |
+
|
| 3 |
+
## Problem
|
| 4 |
+
```
|
| 5 |
+
OSError: The paging file is too small for this operation to complete. (os error 1455)
|
| 6 |
+
```
|
| 7 |
+
|
| 8 |
+
Your system needs more virtual memory to load the large AI models (1.21GB+).
|
| 9 |
+
|
| 10 |
+
## Solution: Increase Windows Virtual Memory
|
| 11 |
+
|
| 12 |
+
### Step-by-Step Instructions:
|
| 13 |
+
|
| 14 |
+
1. **Open System Properties**
|
| 15 |
+
- Press `Windows Key + Pause/Break` OR
|
| 16 |
+
- Right-click "This PC" → Properties → Advanced system settings
|
| 17 |
+
|
| 18 |
+
2. **Access Virtual Memory Settings**
|
| 19 |
+
- Click "Advanced" tab
|
| 20 |
+
- Under "Performance", click "Settings..."
|
| 21 |
+
- Click "Advanced" tab again
|
| 22 |
+
- Under "Virtual memory", click "Change..."
|
| 23 |
+
|
| 24 |
+
3. **Configure Virtual Memory**
|
| 25 |
+
- **Uncheck** "Automatically manage paging file size for all drives"
|
| 26 |
+
- Select your C: drive (or main drive)
|
| 27 |
+
- Select "Custom size"
|
| 28 |
+
- Set values:
|
| 29 |
+
- **Initial size (MB):** 8192 (8 GB)
|
| 30 |
+
- **Maximum size (MB):** 16384 (16 GB)
|
| 31 |
+
- Click "Set"
|
| 32 |
+
- Click "OK" on all dialogs
|
| 33 |
+
|
| 34 |
+
4. **Restart Your Computer**
|
| 35 |
+
- This is required for changes to take effect
|
| 36 |
+
|
| 37 |
+
5. **Try Running the App Again**
|
| 38 |
+
```powershell
|
| 39 |
+
python app.py
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## Alternative: Quick Fix (Temporary)
|
| 43 |
+
|
| 44 |
+
If you can't change virtual memory settings, try these:
|
| 45 |
+
|
| 46 |
+
### Option A: Close Other Programs
|
| 47 |
+
- Close all browsers, apps, and programs
|
| 48 |
+
- This frees up RAM
|
| 49 |
+
- Then try running the app again
|
| 50 |
+
|
| 51 |
+
### Option B: Use Smaller Model (Code Change)
|
| 52 |
+
Edit `classifier/config.py` to use a smaller model if available.
|
| 53 |
+
|
| 54 |
+
### Option C: Run with Priority
|
| 55 |
+
```powershell
|
| 56 |
+
# Run Python with higher priority
|
| 57 |
+
Start-Process python -ArgumentList "app.py" -WindowStyle Normal -Wait
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
## Checking Current Virtual Memory
|
| 61 |
+
|
| 62 |
+
To see your current settings:
|
| 63 |
+
1. Follow steps 1-2 above
|
| 64 |
+
2. Note the current "Total paging file size" at the bottom
|
| 65 |
+
|
| 66 |
+
Typical recommendations:
|
| 67 |
+
- **Minimum:** 1.5x your RAM
|
| 68 |
+
- **Recommended:** 2-3x your RAM
|
| 69 |
+
- **For ML models:** At least 8-16 GB
|
| 70 |
+
|
| 71 |
+
## After Fixing
|
| 72 |
+
|
| 73 |
+
Once virtual memory is increased and computer is restarted:
|
| 74 |
+
```powershell
|
| 75 |
+
cd "C:\Users\Tarak Jha\OneDrive - Coast to Coast Logistics\Desktop\HEALTHBOT\health-query-classifier"
|
| 76 |
+
python app.py
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
The models should load successfully!
|
PRESENTATION_SCRIPT.md
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Medical Q&A Bot - Presentation Script
|
| 2 |
+
|
| 3 |
+
## 🎯 Presentation Overview (10-15 minutes)
|
| 4 |
+
|
| 5 |
+
### Team Introduction (30 seconds)
|
| 6 |
+
"Hello everyone! We're Team HealthBot, and we've developed an intelligent medical query classification and research retrieval system. Our team consists of:
|
| 7 |
+
- David Gray
|
| 8 |
+
- Tarak Jha
|
| 9 |
+
- Sravani Segireddy
|
| 10 |
+
- Riley Millikan
|
| 11 |
+
- Kent R. Spillner"
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## Part 1: Problem Statement (1-2 minutes)
|
| 16 |
+
|
| 17 |
+
### The Challenge
|
| 18 |
+
"In healthcare settings, patients often have questions that fall into two categories:
|
| 19 |
+
1. **Medical queries** - Questions requiring clinical expertise
|
| 20 |
+
2. **Administrative queries** - Questions about billing, scheduling, etc.
|
| 21 |
+
|
| 22 |
+
Currently, all queries are handled the same way, leading to:
|
| 23 |
+
- ❌ Inefficient triage
|
| 24 |
+
- ❌ Delayed responses
|
| 25 |
+
- ❌ Wasted resources
|
| 26 |
+
- ❌ Frustrated patients and staff"
|
| 27 |
+
|
| 28 |
+
### Our Solution
|
| 29 |
+
"We built an AI-powered system that:
|
| 30 |
+
1. ✅ Automatically classifies queries
|
| 31 |
+
2. ✅ Retrieves relevant medical research for medical queries
|
| 32 |
+
3. ✅ Provides confidence scores for transparency
|
| 33 |
+
4. ✅ Offers a user-friendly web interface"
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## Part 2: Technical Architecture (2-3 minutes)
|
| 38 |
+
|
| 39 |
+
### System Overview
|
| 40 |
+
[Show ARCHITECTURE.md diagram]
|
| 41 |
+
|
| 42 |
+
"Our system operates in two main stages:
|
| 43 |
+
|
| 44 |
+
**Stage 1: Classification**
|
| 45 |
+
- Uses a fine-tuned sentence transformer model
|
| 46 |
+
- Classifies queries as Medical, Administrative, or Other
|
| 47 |
+
- Provides confidence scores for each category
|
| 48 |
+
|
| 49 |
+
**Stage 2: Retrieval** (Medical queries only)
|
| 50 |
+
- Implements hybrid search combining:
|
| 51 |
+
- BM25 (keyword-based sparse retrieval)
|
| 52 |
+
- Dense embeddings (semantic similarity)
|
| 53 |
+
- RRF (Reciprocal Rank Fusion) for combining results
|
| 54 |
+
- Optional cross-encoder reranking for improved accuracy"
|
| 55 |
+
|
| 56 |
+
### Data Sources
|
| 57 |
+
"We index three major medical databases:
|
| 58 |
+
- **PubMed**: Peer-reviewed medical research
|
| 59 |
+
- **Miriad**: Medical Q&A database
|
| 60 |
+
- **UniDoc**: Unified medical document corpus
|
| 61 |
+
|
| 62 |
+
This gives us access to thousands of verified medical documents."
|
| 63 |
+
|
| 64 |
+
---
|
| 65 |
+
|
| 66 |
+
## Part 3: Live Demo (5-7 minutes)
|
| 67 |
+
|
| 68 |
+
### Setup
|
| 69 |
+
"Let me show you how it works in practice. We've built a web interface using Gradio."
|
| 70 |
+
[Open http://127.0.0.1:7860]
|
| 71 |
+
|
| 72 |
+
### Demo 1: Medical Query (2 minutes)
|
| 73 |
+
"Let's start with a medical question:"
|
| 74 |
+
|
| 75 |
+
**Type**: "I'm having a really bad rash on my hands. I'm pretty sure it's my eczema flaring up. Is there anything stronger than aquaphor I can use on it?"
|
| 76 |
+
|
| 77 |
+
**Point out**:
|
| 78 |
+
1. "Notice the system classified this as a MEDICAL query"
|
| 79 |
+
2. "Look at the confidence scores - 95% confidence it's medical"
|
| 80 |
+
3. "The system retrieved 10 relevant documents from our medical databases"
|
| 81 |
+
4. "Each document shows multiple relevance scores:"
|
| 82 |
+
- "BM25 score for keyword matching"
|
| 83 |
+
- "Dense score for semantic similarity"
|
| 84 |
+
- "RRF score for combined ranking"
|
| 85 |
+
5. "We can see the document titles, previews, and full metadata"
|
| 86 |
+
|
| 87 |
+
### Demo 2: Administrative Query (1 minute)
|
| 88 |
+
"Now let's try an administrative question:"
|
| 89 |
+
|
| 90 |
+
**Type**: "Hey is there any way I can get an appointment in the next month?"
|
| 91 |
+
|
| 92 |
+
**Point out**:
|
| 93 |
+
1. "The system correctly identified this as ADMINISTRATIVE"
|
| 94 |
+
2. "No document retrieval happens - saving resources"
|
| 95 |
+
3. "This query would be routed to scheduling staff, not medical staff"
|
| 96 |
+
|
| 97 |
+
### Demo 3: Medical Emergency (1 minute)
|
| 98 |
+
"Here's a more urgent medical case:"
|
| 99 |
+
|
| 100 |
+
**Type**: "worst headache of my life with fever and stiff neck"
|
| 101 |
+
|
| 102 |
+
**Point out**:
|
| 103 |
+
1. "Classified as MEDICAL with high confidence"
|
| 104 |
+
2. "Retrieved relevant documents about meningitis symptoms"
|
| 105 |
+
3. "This demonstrates the system can handle urgent queries"
|
| 106 |
+
4. "In a real setting, this could trigger an emergency protocol"
|
| 107 |
+
|
| 108 |
+
### Demo 4: Advanced Features (1 minute)
|
| 109 |
+
"Let me show you some advanced features:"
|
| 110 |
+
|
| 111 |
+
**Adjust settings**:
|
| 112 |
+
1. Change "Number of Results" to 20
|
| 113 |
+
2. Enable "Use Reranker"
|
| 114 |
+
|
| 115 |
+
**Type**: "What are the side effects of statins?"
|
| 116 |
+
|
| 117 |
+
**Point out**:
|
| 118 |
+
1. "We can control how many results to retrieve"
|
| 119 |
+
2. "The reranker improves accuracy but takes longer"
|
| 120 |
+
3. "We have both formatted view and JSON view for different audiences"
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
## Part 4: Technical Implementation (2-3 minutes)
|
| 125 |
+
|
| 126 |
+
### Machine Learning Models
|
| 127 |
+
"Under the hood, we use:
|
| 128 |
+
- **Sentence Transformers**: For generating semantic embeddings
|
| 129 |
+
- **Custom Classification Head**: Neural network trained on healthcare data
|
| 130 |
+
- **FAISS**: For efficient vector similarity search
|
| 131 |
+
- **Cross-Encoder**: Optional reranking for accuracy"
|
| 132 |
+
|
| 133 |
+
### User Interface
|
| 134 |
+
"We implemented two web interfaces:
|
| 135 |
+
1. **Gradio** (primary) - Clean, professional, easy to deploy
|
| 136 |
+
2. **Streamlit** (alternative) - More interactive and customizable
|
| 137 |
+
|
| 138 |
+
Both provide:
|
| 139 |
+
- Real-time classification and retrieval
|
| 140 |
+
- Multiple view modes (formatted and JSON)
|
| 141 |
+
- Adjustable settings
|
| 142 |
+
- Example queries for easy testing"
|
| 143 |
+
|
| 144 |
+
### Code Quality
|
| 145 |
+
"Our codebase demonstrates:
|
| 146 |
+
- ✅ Modular design with clear separation of concerns
|
| 147 |
+
- ✅ Comprehensive documentation
|
| 148 |
+
- ✅ Easy setup and deployment
|
| 149 |
+
- ✅ Error handling and validation
|
| 150 |
+
- ✅ Scalable architecture"
|
| 151 |
+
|
| 152 |
+
---
|
| 153 |
+
|
| 154 |
+
## Part 5: Results & Impact (1-2 minutes)
|
| 155 |
+
|
| 156 |
+
### Performance Metrics
|
| 157 |
+
"Our system achieves:
|
| 158 |
+
- **Classification Accuracy**: ~95%
|
| 159 |
+
- **Response Time**: <1 second for most queries
|
| 160 |
+
- **Retrieval Quality**: High relevance in top results
|
| 161 |
+
- **User Experience**: Clean, intuitive interface"
|
| 162 |
+
|
| 163 |
+
### Real-World Impact
|
| 164 |
+
"This system could:
|
| 165 |
+
1. 📊 Reduce triage time by 60-80%
|
| 166 |
+
2. 💰 Save healthcare costs through efficient routing
|
| 167 |
+
3. 🎯 Improve patient satisfaction with faster responses
|
| 168 |
+
4. 📚 Empower patients with evidence-based information
|
| 169 |
+
5. 👨‍⚕️ Help doctors by providing relevant research context"
|
| 170 |
+
|
| 171 |
+
---
|
| 172 |
+
|
| 173 |
+
## Part 6: Future Enhancements (1 minute)
|
| 174 |
+
|
| 175 |
+
### Potential Improvements
|
| 176 |
+
"Moving forward, we could add:
|
| 177 |
+
- 🔐 User authentication and personalization
|
| 178 |
+
- 📱 Mobile app for patient use
|
| 179 |
+
- 🌍 Multi-language support
|
| 180 |
+
- 📊 Analytics dashboard for healthcare providers
|
| 181 |
+
- 🔗 Integration with existing EMR systems
|
| 182 |
+
- 🗣️ Voice input for accessibility
|
| 183 |
+
- 📈 Continuous learning from user feedback"
|
| 184 |
+
|
| 185 |
+
---
|
| 186 |
+
|
| 187 |
+
## Part 7: Conclusion (30 seconds)
|
| 188 |
+
|
| 189 |
+
### Summary
|
| 190 |
+
"In summary, we've built an intelligent medical query classification and retrieval system that:
|
| 191 |
+
- ✅ Automatically triages patient queries
|
| 192 |
+
- ✅ Retrieves relevant medical research
|
| 193 |
+
- ✅ Provides a professional web interface
|
| 194 |
+
- ✅ Can be easily deployed in real healthcare settings
|
| 195 |
+
|
| 196 |
+
This represents a practical application of AI in healthcare that can improve efficiency and patient outcomes."
|
| 197 |
+
|
| 198 |
+
### Q&A
|
| 199 |
+
"Thank you! We're happy to answer any questions."
|
| 200 |
+
|
| 201 |
+
---
|
| 202 |
+
|
| 203 |
+
## 🎯 Tips for Presenters
|
| 204 |
+
|
| 205 |
+
### Before Presentation
|
| 206 |
+
1. ✅ Test the web UI beforehand
|
| 207 |
+
2. ✅ Have example queries ready
|
| 208 |
+
3. ✅ Check internet connection (for model loading)
|
| 209 |
+
4. ✅ Prepare backup slides in case of technical issues
|
| 210 |
+
5. ✅ Practice the demo flow multiple times
|
| 211 |
+
6. ✅ Assign roles (who presents what)
|
| 212 |
+
|
| 213 |
+
### During Presentation
|
| 214 |
+
1. ✅ Speak clearly and at a steady pace
|
| 215 |
+
2. ✅ Make eye contact with audience
|
| 216 |
+
3. ✅ Explain technical terms briefly
|
| 217 |
+
4. ✅ Show enthusiasm about the project
|
| 218 |
+
5. ✅ Be ready to handle unexpected results
|
| 219 |
+
6. ✅ Keep demo queries visible on screen
|
| 220 |
+
|
| 221 |
+
### Handling Questions
|
| 222 |
+
|
| 223 |
+
**Common Questions & Answers**:
|
| 224 |
+
|
| 225 |
+
Q: "How accurate is the classification?"
|
| 226 |
+
A: "Our classifier achieves approximately 95% accuracy on our test set, with particularly high precision for medical queries."
|
| 227 |
+
|
| 228 |
+
Q: "What about patient privacy?"
|
| 229 |
+
A: "Currently, this is a prototype that doesn't store any data. In production, we'd implement HIPAA-compliant data handling."
|
| 230 |
+
|
| 231 |
+
Q: "How do you handle ambiguous queries?"
|
| 232 |
+
A: "The system provides confidence scores for each category. Low-confidence queries could be flagged for human review."
|
| 233 |
+
|
| 234 |
+
Q: "Can it handle emergency situations?"
|
| 235 |
+
A: "Yes, medical queries can be analyzed for urgency. In production, high-urgency keywords could trigger immediate alerts."
|
| 236 |
+
|
| 237 |
+
Q: "What databases do you use?"
|
| 238 |
+
A: "We index PubMed articles, Miriad medical Q&A, and UniDoc corpus - all verified medical sources."
|
| 239 |
+
|
| 240 |
+
Q: "How long did this take to build?"
|
| 241 |
+
A: "The project took [X weeks/months], including data preparation, model training, and UI development."
|
| 242 |
+
|
| 243 |
+
Q: "Could this be deployed in a real hospital?"
|
| 244 |
+
A: "Absolutely! It would require integration with existing systems, compliance verification, and additional security features."
|
| 245 |
+
|
| 246 |
+
---
|
| 247 |
+
|
| 248 |
+
## 📊 Suggested Slides
|
| 249 |
+
|
| 250 |
+
### Slide 1: Title
|
| 251 |
+
- Project name
|
| 252 |
+
- Team members
|
| 253 |
+
- Course/institution
|
| 254 |
+
|
| 255 |
+
### Slide 2: Problem Statement
|
| 256 |
+
- Current challenges in healthcare
|
| 257 |
+
- Need for automated triage
|
| 258 |
+
|
| 259 |
+
### Slide 3: Solution Overview
|
| 260 |
+
- Two-stage system (classify + retrieve)
|
| 261 |
+
- Key benefits
|
| 262 |
+
|
| 263 |
+
### Slide 4: Architecture Diagram
|
| 264 |
+
- Visual flow chart
|
| 265 |
+
- Key components
|
| 266 |
+
|
| 267 |
+
### Slide 5: Technical Stack
|
| 268 |
+
- ML models used
|
| 269 |
+
- Frameworks and tools
|
| 270 |
+
- Data sources
|
| 271 |
+
|
| 272 |
+
### Slide 6: Live Demo
|
| 273 |
+
- [Switch to web interface]
|
| 274 |
+
|
| 275 |
+
### Slide 7: Results
|
| 276 |
+
- Performance metrics
|
| 277 |
+
- Example outputs
|
| 278 |
+
|
| 279 |
+
### Slide 8: Impact
|
| 280 |
+
- Efficiency gains
|
| 281 |
+
- Cost savings
|
| 282 |
+
- Improved outcomes
|
| 283 |
+
|
| 284 |
+
### Slide 9: Future Work
|
| 285 |
+
- Potential enhancements
|
| 286 |
+
- Scalability considerations
|
| 287 |
+
|
| 288 |
+
### Slide 10: Thank You
|
| 289 |
+
- Team members
|
| 290 |
+
- Questions?
|
| 291 |
+
|
| 292 |
+
---
|
| 293 |
+
|
| 294 |
+
## 🎬 Demo Script Quick Reference
|
| 295 |
+
|
| 296 |
+
```
|
| 297 |
+
1. MEDICAL QUERY
|
| 298 |
+
→ "I have a rash on my hands. Is there anything stronger than aquaphor?"
|
| 299 |
+
→ Show: Classification, confidence, retrieved documents
|
| 300 |
+
|
| 301 |
+
2. ADMIN QUERY
|
| 302 |
+
→ "Can I get an appointment next month?"
|
| 303 |
+
→ Show: Admin classification, no retrieval
|
| 304 |
+
|
| 305 |
+
3. URGENT QUERY
|
| 306 |
+
→ "worst headache of my life with fever and stiff neck"
|
| 307 |
+
→ Show: High confidence, relevant results
|
| 308 |
+
|
| 309 |
+
4. SETTINGS
|
| 310 |
+
→ Adjust number of results
|
| 311 |
+
→ Toggle reranker
|
| 312 |
+
→ Show both view modes
|
| 313 |
+
```
|
| 314 |
+
|
| 315 |
+
---
|
| 316 |
+
|
| 317 |
+
## ✅ Pre-Demo Checklist
|
| 318 |
+
|
| 319 |
+
- [ ] Web UI is running on http://127.0.0.1:7860
|
| 320 |
+
- [ ] All models loaded successfully
|
| 321 |
+
- [ ] Test queries work correctly
|
| 322 |
+
- [ ] Internet connection stable
|
| 323 |
+
- [ ] Screen sharing setup tested
|
| 324 |
+
- [ ] Backup browser tab open
|
| 325 |
+
- [ ] Documentation files ready
|
| 326 |
+
- [ ] Team roles assigned
|
| 327 |
+
- [ ] Timer set for demo sections
|
| 328 |
+
- [ ] Confidence level: HIGH! 🚀
|
| 329 |
+
|
| 330 |
+
---
|
| 331 |
+
|
| 332 |
+
## 🎓 Presentation Day Affirmations
|
| 333 |
+
|
| 334 |
+
"We've built something awesome!"
|
| 335 |
+
"Our system works reliably!"
|
| 336 |
+
"We understand every component!"
|
| 337 |
+
"We can explain this clearly!"
|
| 338 |
+
"We're ready for any question!"
|
| 339 |
+
|
| 340 |
+
**Good luck, team! You've got this! 🌟**
|
| 341 |
+
|
| 342 |
+
---
|
| 343 |
+
|
| 344 |
+
*Prepared by the HealthBot Team*
|
| 345 |
+
*Feel free to customize this script for your specific presentation requirements*
|
QUICKSTART.md
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 Quick Start Guide - Medical Q&A Bot Web UI
|
| 2 |
+
|
| 3 |
+
This guide will help you get the web interface up and running quickly!
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
- Python 3.8 or higher
|
| 8 |
+
- Git (already done since you have the repo)
|
| 9 |
+
- Virtual environment (recommended)
|
| 10 |
+
|
| 11 |
+
## Step-by-Step Setup
|
| 12 |
+
|
| 13 |
+
### 1️⃣ Navigate to the Project Directory
|
| 14 |
+
|
| 15 |
+
```powershell
|
| 16 |
+
cd path\to\health-query-classifier  # replace with wherever you cloned the repo
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
### 2️⃣ Create and Activate Virtual Environment (Recommended)
|
| 20 |
+
|
| 21 |
+
```powershell
|
| 22 |
+
# Create virtual environment
|
| 23 |
+
python -m venv .venv
|
| 24 |
+
|
| 25 |
+
# Activate it (Windows PowerShell)
|
| 26 |
+
.venv\Scripts\Activate.ps1
|
| 27 |
+
|
| 28 |
+
# If you get an execution policy error, run this first:
|
| 29 |
+
# Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
### 3️⃣ Install Dependencies
|
| 33 |
+
|
| 34 |
+
```powershell
|
| 35 |
+
pip install -r requirements.txt
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
This will install:
|
| 39 |
+
- All existing dependencies (PyTorch, sentence-transformers, etc.)
|
| 40 |
+
- **Gradio** - For the web UI (recommended)
|
| 41 |
+
- **Streamlit** - Alternative web UI framework
|
| 42 |
+
|
| 43 |
+
### 4️⃣ Prepare Data (If Not Already Done)
|
| 44 |
+
|
| 45 |
+
```powershell
|
| 46 |
+
python -m adapters.build_corpora
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
This creates the necessary corpus files from PubMed and Miriad databases.
|
| 50 |
+
|
| 51 |
+
### 5️⃣ Launch the Web UI
|
| 52 |
+
|
| 53 |
+
**Option A: Gradio (Recommended)**
|
| 54 |
+
```powershell
|
| 55 |
+
python app.py
|
| 56 |
+
```
|
| 57 |
+
Then open: http://127.0.0.1:7860
|
| 58 |
+
|
| 59 |
+
**Option B: Streamlit (Alternative)**
|
| 60 |
+
```powershell
|
| 61 |
+
streamlit run app_streamlit.py
|
| 62 |
+
```
|
| 63 |
+
Then open: http://localhost:8501
|
| 64 |
+
|
| 65 |
+
## 🎯 Choose Your UI
|
| 66 |
+
|
| 67 |
+
### Gradio (`app.py`)
|
| 68 |
+
✅ Clean, modern interface
|
| 69 |
+
✅ Dual-view (Formatted HTML + JSON)
|
| 70 |
+
✅ Easy to share (can create public links)
|
| 71 |
+
✅ Automatic API generation
|
| 72 |
+
✅ Great for ML demos
|
| 73 |
+
|
| 74 |
+
### Streamlit (`app_streamlit.py`)
|
| 75 |
+
✅ More interactive and customizable
|
| 76 |
+
✅ Sidebar with settings
|
| 77 |
+
✅ Real-time updates
|
| 78 |
+
✅ Better for data science apps
|
| 79 |
+
✅ More widgets and components
|
| 80 |
+
|
| 81 |
+
**We recommend starting with Gradio!** It's simpler and looks very professional.
|
| 82 |
+
|
| 83 |
+
## 🧪 Test with Example Queries
|
| 84 |
+
|
| 85 |
+
Try these queries to see the system in action:
|
| 86 |
+
|
| 87 |
+
1. **Medical Query:**
|
| 88 |
+
> "I'm having a really bad rash on my hands. I'm pretty sure it's my eczema flaring up. Is there anything stronger than aquaphor I can use on it?"
|
| 89 |
+
|
| 90 |
+
2. **Medical Emergency:**
|
| 91 |
+
> "worst headache of my life with fever and stiff neck"
|
| 92 |
+
|
| 93 |
+
3. **Vaccine Question:**
|
| 94 |
+
> "I'm traveling to South America soon. Do I need to get any vaccines before I go?"
|
| 95 |
+
|
| 96 |
+
4. **Administrative Query:**
|
| 97 |
+
> "Hey is there any way I can get an appointment in the next month?"
|
| 98 |
+
|
| 99 |
+
## 🎨 UI Features
|
| 100 |
+
|
| 101 |
+
### Classification
|
| 102 |
+
- Shows whether query is medical, administrative, or other
|
| 103 |
+
- Displays confidence scores for each category
|
| 104 |
+
- Visual progress bars or charts
|
| 105 |
+
|
| 106 |
+
### Document Retrieval (Medical Queries Only)
|
| 107 |
+
- Retrieves top N relevant documents
|
| 108 |
+
- Shows BM25, Dense, and RRF scores
|
| 109 |
+
- Displays document title, text preview, and metadata
|
| 110 |
+
- Toggle between formatted view and raw JSON
|
| 111 |
+
|
| 112 |
+
### Settings
|
| 113 |
+
- **Number of Results:** 1-50 documents
|
| 114 |
+
- **Use Reranker:** Enable for better accuracy (slower)
|
| 115 |
+
|
| 116 |
+
## 🔧 Troubleshooting
|
| 117 |
+
|
| 118 |
+
### "No module named 'gradio'"
|
| 119 |
+
```powershell
|
| 120 |
+
pip install gradio
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### "No corpora files found"
|
| 124 |
+
```powershell
|
| 125 |
+
python -m adapters.build_corpora
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Port Already in Use
|
| 129 |
+
Edit `app.py` and change the port:
|
| 130 |
+
```python
|
| 131 |
+
demo.launch(server_port=8080) # Change 7860 to 8080
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
### Models Not Loading
|
| 135 |
+
Make sure you have your HuggingFace token configured in `env.list`:
|
| 136 |
+
```
|
| 137 |
+
HF_TOKEN="your-huggingface-token"
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
## 📊 Advanced Options
|
| 141 |
+
|
| 142 |
+
### Share Publicly (Gradio)
|
| 143 |
+
Edit `app.py`, line ~255:
|
| 144 |
+
```python
|
| 145 |
+
demo.launch(share=True)  # Creates a temporary public link (expires after a few days; exact lifetime depends on your Gradio version)
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
### Add Authentication (Gradio)
|
| 149 |
+
```python
|
| 150 |
+
demo.launch(auth=("username", "password"))
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
### Change Theme (Streamlit)
|
| 154 |
+
Create `.streamlit/config.toml`:
|
| 155 |
+
```toml
|
| 156 |
+
[theme]
|
| 157 |
+
primaryColor = "#667eea"
|
| 158 |
+
backgroundColor = "#ffffff"
|
| 159 |
+
secondaryBackgroundColor = "#f0f2f6"
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
## 📱 Accessing from Other Devices
|
| 163 |
+
|
| 164 |
+
To access from other devices on your network:
|
| 165 |
+
|
| 166 |
+
1. Find your IP address:
|
| 167 |
+
```powershell
|
| 168 |
+
ipconfig
|
| 169 |
+
```
|
| 170 |
+
Look for "IPv4 Address" (e.g., 192.168.1.100)
|
| 171 |
+
|
| 172 |
+
2. Edit `app.py`:
|
| 173 |
+
```python
|
| 174 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
3. Access from other device:
|
| 178 |
+
```
|
| 179 |
+
http://192.168.1.100:7860
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
## 🎓 For Your Group Presentation
|
| 183 |
+
|
| 184 |
+
### Demo Tips:
|
| 185 |
+
1. Start with the interface loaded beforehand
|
| 186 |
+
2. Have example queries ready
|
| 187 |
+
3. Show both medical and administrative classification
|
| 188 |
+
4. Demonstrate the reranker toggle
|
| 189 |
+
5. Show both formatted and JSON views
|
| 190 |
+
6. Explain the confidence scores
|
| 191 |
+
|
| 192 |
+
### Screenshots to Take:
|
| 193 |
+
- Main interface
|
| 194 |
+
- Classification results
|
| 195 |
+
- Document retrieval results
|
| 196 |
+
- Settings panel
|
| 197 |
+
- Example queries
|
| 198 |
+
|
| 199 |
+
### Key Points to Mention:
|
| 200 |
+
- Built with modern Python web frameworks
|
| 201 |
+
- Real-time classification and retrieval
|
| 202 |
+
- Hybrid search (BM25 + Dense embeddings)
|
| 203 |
+
- Optional reranking for accuracy
|
| 204 |
+
- Clean, professional interface
|
| 205 |
+
|
| 206 |
+
## 📝 Next Steps
|
| 207 |
+
|
| 208 |
+
Once you're comfortable with the UI, you can:
|
| 209 |
+
- Customize the styling (CSS in `app.py`)
|
| 210 |
+
- Add more example queries
|
| 211 |
+
- Integrate with other systems via the API
|
| 212 |
+
- Deploy to cloud (Hugging Face Spaces, AWS, etc.)
|
| 213 |
+
|
| 214 |
+
## 🤝 Team Credits
|
| 215 |
+
|
| 216 |
+
Display proudly on the interface:
|
| 217 |
+
- David Gray
|
| 218 |
+
- Tarak Jha
|
| 219 |
+
- Sravani Segireddy
|
| 220 |
+
- Riley Millikan
|
| 221 |
+
- Kent R. Spillner
|
| 222 |
+
|
| 223 |
+
---
|
| 224 |
+
|
| 225 |
+
**Need help?** Check the full documentation in `UI_README.md` or ask your team members!
|
README.md
CHANGED
|
@@ -1,12 +1,81 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Medical_Document_Retrieval
|
| 3 |
+
app_file: app_retrieval_cached.py
|
| 4 |
+
sdk: gradio
|
| 5 |
+
sdk_version: 6.0.2
|
| 6 |
+
---
|
| 7 |
+
# Health Query Classifier & Research Retriever
|
| 8 |
+
|
| 9 |
+
## Team Members
|
| 10 |
+
* **David Gray**
|
| 11 |
+
* **Tarak Jha**
|
| 12 |
+
* **Sravani Segireddy**
|
| 13 |
+
* **Riley Millikan**
|
| 14 |
+
* **Kent R. Spillner**
|
| 15 |
+
|
| 16 |
+
## Project Description
|
| 17 |
+
This project is a classifier that triages patient queries. If a query is identified as medical, the system retrieves relevant research and presents it to the user.
|
| 18 |
+
|
| 19 |
+
## Workflow
|
| 20 |
+
The system operates in two main stages to optimize patient care and provider efficiency:
|
| 21 |
+
|
| 22 |
+
1. **Classification (Triage)**:
|
| 23 |
+
The tool analyzes the user's input to determine if it is a medical query (requiring clinical attention) or an administrative query (scheduling, billing, etc.).
|
| 24 |
+
|
| 25 |
+
2. **Research Retrieval**:
|
| 26 |
+
If the query is classified as medical, the system searches through indexed medical databases (like PubMed and Miriad) to retrieve relevant research articles and Q/A pairs. This empowers the patient with trustworthy information and provides the doctor with context.
|
| 27 |
+
|
| 28 |
+
### Training Script
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
python3 -m classifier.train
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Running the System Locally
|
| 35 |
+
|
| 36 |
+
### Prerequisites
|
| 37 |
+
* Git
|
| 38 |
+
* Python 3
|
| 39 |
+
|
| 40 |
+
### Setup & Configuration
|
| 41 |
+
|
| 42 |
+
1. **Clone the repository**
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
git clone https://github.com/davidgraymi/health-query-classifier.git
|
| 46 |
+
cd health-query-classifier
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
2. **Configure environment variables**
|
| 50 |
+
|
| 51 |
+
This project uses an `env.list` file for configuration. Create this file in the root directory.
|
| 52 |
+
```ini
|
| 53 |
+
# env.list
|
| 54 |
+
HF_TOKEN="your-huggingface-token"
|
| 55 |
+
```
|
| 56 |
+
* **HF_TOKEN**: Access token can be generated via [huggingface](https://huggingface.co/settings/tokens). The token must have read permissions.
|
| 57 |
+
|
| 58 |
+
3. **Create a python virtual environment**
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python3 -m venv .venv
|
| 62 |
+
source .venv/bin/activate
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
4. **Install dependencies**
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
pip install -r requirements.txt
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### Data Setup
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
python3 adapters/build_corpora.py
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Execution
|
| 78 |
+
|
| 79 |
+
```bash
|
| 80 |
+
python3 main.py
|
| 81 |
+
```
|
UI_GUIDE.md
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Summary: Available UI Options for Medical Q&A Bot
|
| 2 |
+
|
| 3 |
+
## 🎯 Three Versions Created
|
| 4 |
+
|
| 5 |
+
### 1. **app_demo.py** ⚡ RECOMMENDED FOR DEMOS
|
| 6 |
+
**Port:** 7863
|
| 7 |
+
**Speed:** Instant (<1 second)
|
| 8 |
+
**Features:**
|
| 9 |
+
- ✅ Real-time classification (medical vs administrative)
|
| 10 |
+
- ✅ Confidence scores with visualization
|
| 11 |
+
- ✅ Action recommendations
|
| 12 |
+
- ✅ Uses your group's trained models
|
| 13 |
+
- ❌ No document retrieval (for speed)
|
| 14 |
+
|
| 15 |
+
**Best for:**
|
| 16 |
+
- Class presentations
|
| 17 |
+
- Quick demonstrations
|
| 18 |
+
- Testing classification accuracy
|
| 19 |
+
- When time is limited
|
| 20 |
+
|
| 21 |
+
**Run:** `python app_demo.py`
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
### 2. **app_full.py** 🔬 COMPLETE SYSTEM
|
| 26 |
+
**Port:** 7864
|
| 27 |
+
**Speed:** First query: 6-10 minutes, subsequent: 2-5 seconds
|
| 28 |
+
**Features:**
|
| 29 |
+
- ✅ Real-time classification
|
| 30 |
+
- ✅ **Full document retrieval** from PubMed & Miriad
|
| 31 |
+
- ✅ BM25 + Dense search + RRF fusion
|
| 32 |
+
- ✅ Optional cross-encoder reranking
|
| 33 |
+
- ⚠️ Very slow first initialization
|
| 34 |
+
|
| 35 |
+
**Best for:**
|
| 36 |
+
- Showing full system capabilities
|
| 37 |
+
- When you have 10+ minutes to wait
|
| 38 |
+
- Detailed technical demonstrations
|
| 39 |
+
- Proving retrieval works
|
| 40 |
+
|
| 41 |
+
**Run:** `python app_full.py`
|
| 42 |
+
|
| 43 |
+
**⚠️ WARNING:** First medical query takes 6-10 minutes because:
|
| 44 |
+
- Loads ~200MB+ of medical corpus data
|
| 45 |
+
- Builds BM25 keyword index
|
| 46 |
+
- Generates embeddings for ALL documents (this is the slow part)
|
| 47 |
+
- Builds FAISS vector index
|
| 48 |
+
|
| 49 |
+
---
|
| 50 |
+
|
| 51 |
+
### 3. **app.py** / **app_safe.py** / **app_lightweight.py** 🔧 EXPERIMENTAL
|
| 52 |
+
These were intermediate versions created while troubleshooting.
|
| 53 |
+
**Not recommended for use.**
|
| 54 |
+
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
## 🎬 Recommendation for Your Group Presentation
|
| 58 |
+
|
| 59 |
+
### Strategy 1: Fast Demo (5 minutes)
|
| 60 |
+
Use **`app_demo.py`** only:
|
| 61 |
+
1. Show classification working instantly
|
| 62 |
+
2. Test medical vs administrative queries
|
| 63 |
+
3. Highlight confidence scores
|
| 64 |
+
4. Explain that retrieval is available but disabled for demo speed
|
| 65 |
+
5. Show the codebase that supports retrieval (team/candidates.py)
|
| 66 |
+
|
| 67 |
+
**Advantage:** Reliable, professional, no waiting
|
| 68 |
+
|
| 69 |
+
---
|
| 70 |
+
|
| 71 |
+
### Strategy 2: Split Demo (15+ minutes)
|
| 72 |
+
Use BOTH versions:
|
| 73 |
+
|
| 74 |
+
**Part 1:** Use `app_demo.py` for quick classification demos (5 min)
|
| 75 |
+
- Show multiple queries rapidly
|
| 76 |
+
- Demonstrate accuracy
|
| 77 |
+
|
| 78 |
+
**Part 2:** Switch to `app_full.py` that you pre-initialized (10 min)
|
| 79 |
+
- **Before presentation:** Run `app_full.py` and make ONE medical query to initialize
|
| 80 |
+
- Wait the 10 minutes for it to build indexes
|
| 81 |
+
- Keep it running
|
| 82 |
+
- During presentation: Show actual document retrieval working fast
|
| 83 |
+
|
| 84 |
+
**Advantage:** Shows both speed AND capabilities
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
### Strategy 3: Video Backup
|
| 89 |
+
1. Use `app_demo.py` for live demo
|
| 90 |
+
2. Record a video of `app_full.py` working with retrieval
|
| 91 |
+
3. Show video during presentation if needed
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## 📊 Technical Details to Mention
|
| 96 |
+
|
| 97 |
+
### Your Group's Implementation:
|
| 98 |
+
- **Classification Model:** Fine-tuned sentence-transformers (embeddinggemma-300m-medical)
|
| 99 |
+
- **Hybrid Retrieval:** BM25 (sparse) + Dense embeddings (semantic)
|
| 100 |
+
- **Fusion Algorithm:** Reciprocal Rank Fusion (RRF)
|
| 101 |
+
- **Data Sources:** PubMed Medical Q&A + Miriad corpus
|
| 102 |
+
- **Optional Enhancement:** Cross-encoder reranker for accuracy
|
| 103 |
+
|
| 104 |
+
### Why Retrieval is Slow:
|
| 105 |
+
- Real ML systems need to index large datasets
|
| 106 |
+
- Your corpus has thousands of medical documents
|
| 107 |
+
- CPU-only inference (no GPU acceleration available)
|
| 108 |
+
- This is a REAL implementation, not a toy demo
|
| 109 |
+
|
| 110 |
+
### Production Solutions:
|
| 111 |
+
- Pre-build and save indexes (don't rebuild each time)
|
| 112 |
+
- Use GPU for faster embedding
|
| 113 |
+
- Implement caching
|
| 114 |
+
- Deploy on cloud with more resources
|
| 115 |
+
|
| 116 |
+
---
|
| 117 |
+
|
| 118 |
+
## 💡 Demo Script Suggestions
|
| 119 |
+
|
| 120 |
+
### Opening (30 seconds):
|
| 121 |
+
"We built an AI system that automatically classifies patient queries and retrieves relevant medical research. Let me show you how it works..."
|
| 122 |
+
|
| 123 |
+
### Classification Demo (2-3 minutes):
|
| 124 |
+
"First, our classification system determines if a query is medical or administrative..."
|
| 125 |
+
[Use app_demo.py, try 3-4 different queries]
|
| 126 |
+
|
| 127 |
+
### Technical Explanation (2 minutes):
|
| 128 |
+
"Under the hood, we use:
|
| 129 |
+
- A fine-tuned 300-million parameter transformer model
|
| 130 |
+
- Hybrid search combining keyword matching and semantic similarity
|
| 131 |
+
- Reciprocal Rank Fusion to combine results
|
| 132 |
+
- Medical corpora from PubMed and Miriad databases"
|
| 133 |
+
|
| 134 |
+
### Show Retrieval (Optional, if pre-initialized):
|
| 135 |
+
"Now let me show you actual document retrieval..."
|
| 136 |
+
[Use app_full.py if you pre-initialized it]
|
| 137 |
+
|
| 138 |
+
### Closing (30 seconds):
|
| 139 |
+
"This demonstrates how AI can improve healthcare triage, reduce response times, and provide evidence-based information to both patients and providers."
|
| 140 |
+
|
| 141 |
+
---
|
| 142 |
+
|
| 143 |
+
## 🚀 Quick Start Commands
|
| 144 |
+
|
| 145 |
+
```powershell
|
| 146 |
+
# For demos and presentations
|
| 147 |
+
python app_demo.py
|
| 148 |
+
# Access at: http://127.0.0.1:7863
|
| 149 |
+
|
| 150 |
+
# For full system (wait 10 minutes after first query)
|
| 151 |
+
python app_full.py
|
| 152 |
+
# Access at: http://127.0.0.1:7864
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
---
|
| 156 |
+
|
| 157 |
+
## ✅ What You Successfully Built
|
| 158 |
+
|
| 159 |
+
1. ✅ Working web UI with professional design
|
| 160 |
+
2. ✅ Real-time classification using your trained model
|
| 161 |
+
3. ✅ Full retrieval system integrated
|
| 162 |
+
4. ✅ Two versions: fast demo + complete system
|
| 163 |
+
5. ✅ Comprehensive documentation
|
| 164 |
+
6. ✅ Example queries
|
| 165 |
+
7. ✅ Clear visualization of results
|
| 166 |
+
|
| 167 |
+
**You have everything you need for a successful presentation!**
|
| 168 |
+
|
| 169 |
+
---
|
| 170 |
+
|
| 171 |
+
## 🎯 Final Recommendation
|
| 172 |
+
|
| 173 |
+
**For your presentation, use `app_demo.py`**
|
| 174 |
+
|
| 175 |
+
It shows your ML work instantly and professionally. You can explain:
|
| 176 |
+
- "The classification happens in real-time"
|
| 177 |
+
- "The full system includes retrieval which we can show separately"
|
| 178 |
+
- "This demonstrates the core AI capability"
|
| 179 |
+
|
| 180 |
+
If anyone asks about retrieval, you can:
|
| 181 |
+
- Show the code in `team/candidates.py`
|
| 182 |
+
- Explain the hybrid search architecture
|
| 183 |
+
- Mention it's fully implemented but slow due to index building
|
| 184 |
+
|
| 185 |
+
**This is the smart approach for a live demo!**
|
UI_IMPLEMENTATION.md
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Medical Q&A Bot - UI Implementation Summary
|
| 2 |
+
|
| 3 |
+
## 📦 What Was Added
|
| 4 |
+
|
| 5 |
+
### New Files Created:
|
| 6 |
+
|
| 7 |
+
1. **`app.py`** - Main Gradio web interface (RECOMMENDED)
|
| 8 |
+
- Clean, modern UI with gradient header
|
| 9 |
+
- Dual-view mode (formatted HTML + JSON)
|
| 10 |
+
- Real-time classification and retrieval
|
| 11 |
+
- Example queries built-in
|
| 12 |
+
- Automatic API generation
|
| 13 |
+
|
| 14 |
+
2. **`app_streamlit.py`** - Alternative Streamlit interface
|
| 15 |
+
- Interactive sidebar with settings
|
| 16 |
+
- Card-based result display
|
| 17 |
+
- Progress bars for confidence scores
|
| 18 |
+
- More customizable styling options
|
| 19 |
+
|
| 20 |
+
3. **`QUICKSTART.md`** - Step-by-step setup guide
|
| 21 |
+
- Installation instructions
|
| 22 |
+
- How to launch both UIs
|
| 23 |
+
- Troubleshooting tips
|
| 24 |
+
- Demo tips for presentations
|
| 25 |
+
|
| 26 |
+
4. **`UI_README.md`** - Comprehensive documentation
|
| 27 |
+
- Feature descriptions
|
| 28 |
+
- Configuration options
|
| 29 |
+
- Advanced usage
|
| 30 |
+
- API information
|
| 31 |
+
|
| 32 |
+
5. **`setup_ui.py`** - Automated setup script
|
| 33 |
+
- Installs all dependencies
|
| 34 |
+
- Checks for required data files
|
| 35 |
+
- Verifies setup completeness
|
| 36 |
+
|
| 37 |
+
### Modified Files:
|
| 38 |
+
|
| 39 |
+
1. **`requirements.txt`** - Added:
|
| 40 |
+
- `gradio` - Main UI framework
|
| 41 |
+
- `streamlit` - Alternative UI framework
|
| 42 |
+
|
| 43 |
+
## 🎯 Key Features
|
| 44 |
+
|
| 45 |
+
### Classification Display
|
| 46 |
+
- ✅ Shows query type (Medical/Administrative/Other)
|
| 47 |
+
- ✅ Confidence scores with visual indicators
|
| 48 |
+
- ✅ Color-coded results
|
| 49 |
+
|
| 50 |
+
### Document Retrieval
|
| 51 |
+
- ✅ Retrieves top N relevant documents (1-50)
|
| 52 |
+
- ✅ Shows BM25, Dense, and RRF scores
|
| 53 |
+
- ✅ Displays document preview with full metadata
|
| 54 |
+
- ✅ Optional reranker for better accuracy
|
| 55 |
+
|
| 56 |
+
### User Experience
|
| 57 |
+
- ✅ Clean, professional design
|
| 58 |
+
- ✅ Example queries for easy testing
|
| 59 |
+
- ✅ Formatted and raw JSON views
|
| 60 |
+
- ✅ Responsive layout
|
| 61 |
+
- ✅ Real-time processing
|
| 62 |
+
|
| 63 |
+
## 🚀 Quick Start (TL;DR)
|
| 64 |
+
|
| 65 |
+
```powershell
|
| 66 |
+
# 1. Install dependencies
|
| 67 |
+
pip install -r requirements.txt
|
| 68 |
+
|
| 69 |
+
# 2. Build data (if needed)
|
| 70 |
+
python -m adapters.build_corpora
|
| 71 |
+
|
| 72 |
+
# 3. Launch UI (choose one)
|
| 73 |
+
python app.py # Gradio (recommended)
|
| 74 |
+
streamlit run app_streamlit.py # Streamlit (alternative)
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## 🌟 Which UI to Use?
|
| 78 |
+
|
| 79 |
+
### Use Gradio (`app.py`) if you want:
|
| 80 |
+
- Quick setup and deployment
|
| 81 |
+
- Clean, minimal interface
|
| 82 |
+
- Easy sharing (public links)
|
| 83 |
+
- Automatic REST API
|
| 84 |
+
- Better for demos and presentations
|
| 85 |
+
|
| 86 |
+
### Use Streamlit (`app_streamlit.py`) if you want:
|
| 87 |
+
- More interactive controls
|
| 88 |
+
- Sidebar with settings
|
| 89 |
+
- More customization options
|
| 90 |
+
- Data-science focused interface
|
| 91 |
+
|
| 92 |
+
**Recommendation:** Start with **Gradio** - it's simpler and looks very professional!
|
| 93 |
+
|
| 94 |
+
## 📊 How It Works
|
| 95 |
+
|
| 96 |
+
```
|
| 97 |
+
User Input → Gradio/Streamlit Interface
|
| 98 |
+
↓
|
| 99 |
+
Classifier (classify query as medical/admin/other)
|
| 100 |
+
↓
|
| 101 |
+
If Medical → Retriever (BM25 + Dense + RRF)
|
| 102 |
+
↓
|
| 103 |
+
Optional Reranker (cross-encoder)
|
| 104 |
+
↓
|
| 105 |
+
Display Results (formatted cards + JSON)
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## 🎓 For Your Presentation
|
| 109 |
+
|
| 110 |
+
### Demo Flow:
|
| 111 |
+
1. **Open the interface** - Show the clean design
|
| 112 |
+
2. **Enter a medical query** - e.g., "I have a rash on my hands"
|
| 113 |
+
3. **Show classification** - Highlight confidence scores
|
| 114 |
+
4. **Display results** - Show retrieved medical documents
|
| 115 |
+
5. **Try administrative query** - Show different handling
|
| 116 |
+
6. **Toggle settings** - Demonstrate reranker and result count
|
| 117 |
+
7. **Show JSON view** - For technical audience
|
| 118 |
+
|
| 119 |
+
### Key Talking Points:
|
| 120 |
+
- "We built a user-friendly web interface using Gradio"
|
| 121 |
+
- "The system classifies queries in real-time with confidence scores"
|
| 122 |
+
- "For medical queries, it retrieves relevant research from the PubMed and MIRIAD databases"
|
| 123 |
+
- "Uses hybrid search combining BM25 and dense embeddings"
|
| 124 |
+
- "Optional reranker for improved accuracy"
|
| 125 |
+
|
| 126 |
+
## 🎨 Customization Options
|
| 127 |
+
|
| 128 |
+
### Change Theme/Colors
|
| 129 |
+
Edit the CSS in `app.py`:
|
| 130 |
+
```python
|
| 131 |
+
custom_css = """
|
| 132 |
+
.header {
|
| 133 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 134 |
+
}
|
| 135 |
+
"""
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
### Add More Example Queries
|
| 139 |
+
Edit the `examples` list in `app.py`:
|
| 140 |
+
```python
|
| 141 |
+
gr.Examples(
|
| 142 |
+
examples=[
|
| 143 |
+
["Your custom query here"],
|
| 144 |
+
],
|
| 145 |
+
inputs=query_input,
|
| 146 |
+
)
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
### Change Port
|
| 150 |
+
```python
|
| 151 |
+
demo.launch(server_port=8080) # Default is 7860
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### Enable Public Sharing
|
| 155 |
+
```python
|
| 156 |
+
demo.launch(share=True) # Creates 72-hour public link
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
## 🔧 Technical Details
|
| 160 |
+
|
| 161 |
+
### Gradio Version
|
| 162 |
+
- Framework: Gradio 4.x
|
| 163 |
+
- Features: Blocks API, custom CSS, examples, tabs
|
| 164 |
+
- Auto-generates REST API at `/docs`
|
| 165 |
+
|
| 166 |
+
### Streamlit Version
|
| 167 |
+
- Framework: Streamlit
|
| 168 |
+
- Features: Sidebar, caching, progress bars, metrics
|
| 169 |
+
- More suitable for data exploration
|
| 170 |
+
|
| 171 |
+
### Integration
|
| 172 |
+
- Uses existing classifier and retriever modules
|
| 173 |
+
- No changes to core logic
|
| 174 |
+
- Models loaded once and cached
|
| 175 |
+
- Async processing for better UX
|
| 176 |
+
|
| 177 |
+
## 📁 File Structure
|
| 178 |
+
|
| 179 |
+
```
|
| 180 |
+
health-query-classifier/
|
| 181 |
+
├── app.py # ← Main Gradio UI
|
| 182 |
+
├── app_streamlit.py # ← Alternative Streamlit UI
|
| 183 |
+
├── setup_ui.py # ← Setup script
|
| 184 |
+
├── QUICKSTART.md # ← Quick start guide
|
| 185 |
+
├── UI_README.md # ← Detailed documentation
|
| 186 |
+
├── requirements.txt # ← Updated with gradio/streamlit
|
| 187 |
+
├── classifier/
|
| 188 |
+
│ ├── infer.py # Used by UI
|
| 189 |
+
│ └── ...
|
| 190 |
+
├── retriever/
|
| 191 |
+
│ └── ...
|
| 192 |
+
└── team/
|
| 193 |
+
├── candidates.py # Used by UI
|
| 194 |
+
└── ...
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
## 🐛 Common Issues & Solutions
|
| 198 |
+
|
| 199 |
+
### "Module not found: gradio"
|
| 200 |
+
```powershell
|
| 201 |
+
pip install gradio
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
### "No corpora files found"
|
| 205 |
+
```powershell
|
| 206 |
+
python -m adapters.build_corpora
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
### Models take long to load
|
| 210 |
+
- This is normal on first run
|
| 211 |
+
- Models are cached after initial load
|
| 212 |
+
- Consider using smaller models for faster demo
|
| 213 |
+
|
| 214 |
+
### Port already in use
|
| 215 |
+
- Change port in `app.py` (line ~255)
|
| 216 |
+
- Or kill the process using that port
|
| 217 |
+
|
| 218 |
+
## 🌐 Deployment Options
|
| 219 |
+
|
| 220 |
+
### Local (Current Setup)
|
| 221 |
+
- Best for development and demos
|
| 222 |
+
- Access via localhost
|
| 223 |
+
|
| 224 |
+
### Hugging Face Spaces (Free)
|
| 225 |
+
- Free hosting for Gradio apps
|
| 226 |
+
- Easy to deploy
|
| 227 |
+
- Public URL
|
| 228 |
+
|
| 229 |
+
### Cloud Platforms
|
| 230 |
+
- AWS, Google Cloud, Azure
|
| 231 |
+
- More control and scalability
|
| 232 |
+
- Requires more setup
|
| 233 |
+
|
| 234 |
+
## 📈 Future Enhancements
|
| 235 |
+
|
| 236 |
+
Potential additions for future development:
|
| 237 |
+
- User authentication
|
| 238 |
+
- Query history
|
| 239 |
+
- Bookmarking results
|
| 240 |
+
- Export to PDF
|
| 241 |
+
- Multi-language support
|
| 242 |
+
- Voice input
|
| 243 |
+
- Mobile app
|
| 244 |
+
- Analytics dashboard
|
| 245 |
+
|
| 246 |
+
## 👥 Team Contributions
|
| 247 |
+
|
| 248 |
+
This UI implementation demonstrates:
|
| 249 |
+
- Full-stack development skills
|
| 250 |
+
- ML model integration
|
| 251 |
+
- User experience design
|
| 252 |
+
- Modern web frameworks
|
| 253 |
+
- Professional documentation
|
| 254 |
+
|
| 255 |
+
Perfect for showcasing in your group project presentation!
|
| 256 |
+
|
| 257 |
+
## 📞 Support
|
| 258 |
+
|
| 259 |
+
For issues or questions:
|
| 260 |
+
1. Check `QUICKSTART.md` for setup issues
|
| 261 |
+
2. Check `UI_README.md` for feature documentation
|
| 262 |
+
3. Review error messages carefully
|
| 263 |
+
4. Contact team members
|
| 264 |
+
|
| 265 |
+
---
|
| 266 |
+
|
| 267 |
+
## ✅ Ready to Present!
|
| 268 |
+
|
| 269 |
+
Your medical Q&A bot now has a professional web interface that:
|
| 270 |
+
- ✅ Looks modern and clean
|
| 271 |
+
- ✅ Is easy to use
|
| 272 |
+
- ✅ Demonstrates your ML capabilities
|
| 273 |
+
- ✅ Provides clear results
|
| 274 |
+
- ✅ Is well-documented
|
| 275 |
+
- ✅ Can be easily deployed
|
| 276 |
+
|
| 277 |
+
**Great work, team!** 🎉
|
| 278 |
+
|
| 279 |
+
---
|
| 280 |
+
|
| 281 |
+
*Created by: Tarak Jha, with contributions from the entire team*
|
| 282 |
+
*Team: David Gray • Tarak Jha • Sravani Segireddy • Riley Millikan • Kent R. Spillner*
|
UI_README.md
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Medical Q&A Bot - Web UI
|
| 2 |
+
|
| 3 |
+
This is a user-friendly web interface for the Health Query Classifier & Research Retriever system.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
✨ **Clean, Modern Interface** - Built with Gradio for an intuitive user experience
|
| 8 |
+
|
| 9 |
+
🎯 **Query Classification** - Automatically triages queries as medical or administrative with confidence scores
|
| 10 |
+
|
| 11 |
+
📚 **Intelligent Retrieval** - Retrieves relevant medical research from PubMed and Miriad databases
|
| 12 |
+
|
| 13 |
+
🔍 **Dual View Modes** - View results in formatted HTML or raw JSON
|
| 14 |
+
|
| 15 |
+
⚙️ **Customizable Settings** - Adjust number of results and toggle reranker for better accuracy
|
| 16 |
+
|
| 17 |
+
## Quick Start
|
| 18 |
+
|
| 19 |
+
### 1. Install Dependencies
|
| 20 |
+
|
| 21 |
+
Make sure you have the updated requirements installed:
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
pip install -r requirements.txt
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
This will install Gradio along with all other dependencies.
|
| 28 |
+
|
| 29 |
+
### 2. Prepare Data
|
| 30 |
+
|
| 31 |
+
If you haven't already, build the corpora:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
python -m adapters.build_corpora
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### 3. Launch the Web UI
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
python app.py
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
The interface will be available at: **http://127.0.0.1:7860**
|
| 44 |
+
|
| 45 |
+
## Using the Interface
|
| 46 |
+
|
| 47 |
+
1. **Enter Your Query** - Type your health-related question in the text box
|
| 48 |
+
2. **Adjust Settings** (optional):
|
| 49 |
+
- **Number of Results**: Control how many documents to retrieve (1-50)
|
| 50 |
+
- **Use Reranker**: Enable for more accurate results (slower)
|
| 51 |
+
3. **Click "Analyze Query"** or press Enter
|
| 52 |
+
4. **View Results**:
|
| 53 |
+
- **Classification**: See how your query was categorized
|
| 54 |
+
- **Formatted View**: Readable cards with document information
|
| 55 |
+
- **JSON View**: Raw data for technical analysis
|
| 56 |
+
|
| 57 |
+
## Example Queries
|
| 58 |
+
|
| 59 |
+
Try these example queries to see the system in action:
|
| 60 |
+
|
| 61 |
+
- "I'm having a really bad rash on my hands. Is there anything stronger than aquaphor I can use?"
|
| 62 |
+
- "I'm traveling to South America soon. Do I need to get any vaccines before I go?"
|
| 63 |
+
- "worst headache of my life with fever and stiff neck"
|
| 64 |
+
- "Hey is there any way I can get an appointment in the next month?"
|
| 65 |
+
|
| 66 |
+
## Configuration Options
|
| 67 |
+
|
| 68 |
+
### Sharing Your Interface
|
| 69 |
+
|
| 70 |
+
To create a public shareable link (72 hours), modify `app.py`:
|
| 71 |
+
|
| 72 |
+
```python
|
| 73 |
+
demo.launch(
|
| 74 |
+
share=True, # Creates a public link
|
| 75 |
+
server_name="127.0.0.1",
|
| 76 |
+
server_port=7860,
|
| 77 |
+
)
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### Custom Port
|
| 81 |
+
|
| 82 |
+
Change the port if 7860 is already in use:
|
| 83 |
+
|
| 84 |
+
```python
|
| 85 |
+
demo.launch(
|
| 86 |
+
server_port=8080, # Use your preferred port
|
| 87 |
+
)
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Authentication
|
| 91 |
+
|
| 92 |
+
Add password protection:
|
| 93 |
+
|
| 94 |
+
```python
|
| 95 |
+
demo.launch(
|
| 96 |
+
auth=("username", "password"), # Simple auth
|
| 97 |
+
)
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## Architecture
|
| 101 |
+
|
| 102 |
+
The UI integrates with your existing codebase:
|
| 103 |
+
|
| 104 |
+
```
|
| 105 |
+
User Query → Gradio Interface → Classifier → Retriever → Results Display
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
- **Frontend**: Gradio (Python-based web framework)
|
| 109 |
+
- **Classification**: Uses your trained classifier model
|
| 110 |
+
- **Retrieval**: Hybrid search (BM25 + Dense embeddings + RRF)
|
| 111 |
+
- **Reranking**: Optional cross-encoder reranker
|
| 112 |
+
|
| 113 |
+
## Troubleshooting
|
| 114 |
+
|
| 115 |
+
### Models Not Loading
|
| 116 |
+
|
| 117 |
+
Ensure you have the classifier checkpoint and data files:
|
| 118 |
+
```bash
|
| 119 |
+
ls -la data/corpora/
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
### Port Already in Use
|
| 123 |
+
|
| 124 |
+
Change the port number in `app.py` or kill the process using port 7860.
|
| 125 |
+
|
| 126 |
+
### Gradio Import Error
|
| 127 |
+
|
| 128 |
+
Make sure Gradio is installed:
|
| 129 |
+
```bash
|
| 130 |
+
pip install gradio
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
### No Medical Documents Found
|
| 134 |
+
|
| 135 |
+
Verify that corpora files exist in `data/corpora/`:
|
| 136 |
+
- `medical_qa.jsonl`
|
| 137 |
+
- `miriad_text.jsonl`
|
| 138 |
+
- `unidoc_qa.jsonl`
|
| 139 |
+
|
| 140 |
+
Run the build script if missing:
|
| 141 |
+
```bash
|
| 142 |
+
python -m adapters.build_corpora
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
## Advanced Features
|
| 146 |
+
|
| 147 |
+
### API Mode
|
| 148 |
+
|
| 149 |
+
Gradio automatically creates a REST API alongside the web UI. Access the API docs at:
|
| 150 |
+
```
|
| 151 |
+
http://127.0.0.1:7860/docs
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### Embedding the Interface
|
| 155 |
+
|
| 156 |
+
You can embed the Gradio interface in other web applications using iframes:
|
| 157 |
+
|
| 158 |
+
```html
|
| 159 |
+
<iframe src="http://127.0.0.1:7860" width="100%" height="800px"></iframe>
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
## Team
|
| 163 |
+
|
| 164 |
+
- **David Gray**
|
| 165 |
+
- **Tarak Jha**
|
| 166 |
+
- **Sravani Segireddy**
|
| 167 |
+
- **Riley Millikan**
|
| 168 |
+
- **Kent R. Spillner**
|
| 169 |
+
|
| 170 |
+
## License
|
| 171 |
+
|
| 172 |
+
See the main README.md for project license information.
|
UI_SUMMARY.md
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎉 Medical Q&A Bot - UI Implementation Complete!
|
| 2 |
+
|
| 3 |
+
## ✅ What You Now Have
|
| 4 |
+
|
| 5 |
+
### 🖥️ Two Complete Web Interfaces
|
| 6 |
+
|
| 7 |
+
#### 1. **Gradio Interface** (`app.py`) - RECOMMENDED ⭐
|
| 8 |
+
- Clean, modern design with gradient styling
|
| 9 |
+
- Dual-view mode (Formatted + JSON)
|
| 10 |
+
- Built-in example queries
|
| 11 |
+
- Easy to share and deploy
|
| 12 |
+
- Automatic REST API generation
|
| 13 |
+
- Launch with: `python app.py`
|
| 14 |
+
- Access at: http://127.0.0.1:7860
|
| 15 |
+
|
| 16 |
+
#### 2. **Streamlit Interface** (`app_streamlit.py`) - ALTERNATIVE
|
| 17 |
+
- Interactive sidebar with live controls
|
| 18 |
+
- Card-based result display
|
| 19 |
+
- Progress bars and metrics
|
| 20 |
+
- More customization options
|
| 21 |
+
- Launch with: `streamlit run app_streamlit.py`
|
| 22 |
+
- Access at: http://localhost:8501
|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
## 📚 Complete Documentation Suite
|
| 27 |
+
|
| 28 |
+
### Quick Start Guides
|
| 29 |
+
1. **QUICKSTART.md** - Step-by-step setup (5 minutes)
|
| 30 |
+
2. **launch_ui.bat** - Windows batch launcher (double-click to run)
|
| 31 |
+
3. **launch_ui.ps1** - PowerShell launcher (right-click → Run with PowerShell)
|
| 32 |
+
4. **setup_ui.py** - Automated setup script
|
| 33 |
+
|
| 34 |
+
### Comprehensive Documentation
|
| 35 |
+
1. **UI_README.md** - Complete UI feature documentation
|
| 36 |
+
2. **UI_IMPLEMENTATION.md** - Implementation details and summary
|
| 37 |
+
3. **ARCHITECTURE.md** - System architecture with diagrams
|
| 38 |
+
4. **PRESENTATION_SCRIPT.md** - Complete presentation guide with demo script
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
## 🚀 How to Get Started (3 Easy Steps)
|
| 43 |
+
|
| 44 |
+
### Option 1: PowerShell Launcher (Easiest!)
|
| 45 |
+
```powershell
|
| 46 |
+
# Just double-click or run:
|
| 47 |
+
.\launch_ui.ps1
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Option 2: Command Line
|
| 51 |
+
```powershell
|
| 52 |
+
# 1. Install dependencies
|
| 53 |
+
pip install -r requirements.txt
|
| 54 |
+
|
| 55 |
+
# 2. Build data (if needed)
|
| 56 |
+
python -m adapters.build_corpora
|
| 57 |
+
|
| 58 |
+
# 3. Launch!
|
| 59 |
+
python app.py
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### Option 3: Setup Script
|
| 63 |
+
```powershell
|
| 64 |
+
# Run the automated setup
|
| 65 |
+
python setup_ui.py
|
| 66 |
+
|
| 67 |
+
# Then launch
|
| 68 |
+
python app.py
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
## 🎨 Key Features
|
| 74 |
+
|
| 75 |
+
### Classification
|
| 76 |
+
✅ Automatic query classification (Medical/Administrative/Other)
|
| 77 |
+
✅ Confidence scores for transparency
|
| 78 |
+
✅ Visual indicators and progress bars
|
| 79 |
+
✅ Color-coded results
|
| 80 |
+
|
| 81 |
+
### Retrieval
|
| 82 |
+
✅ Hybrid search (BM25 + Dense + RRF)
|
| 83 |
+
✅ Retrieves from PubMed, Miriad, and UniDoc
|
| 84 |
+
✅ Adjustable number of results (1-50)
|
| 85 |
+
✅ Optional cross-encoder reranking
|
| 86 |
+
✅ Multiple relevance scores per document
|
| 87 |
+
|
| 88 |
+
### User Experience
|
| 89 |
+
✅ Clean, professional interface
|
| 90 |
+
✅ Example queries built-in
|
| 91 |
+
✅ Real-time processing
|
| 92 |
+
✅ Formatted and JSON view modes
|
| 93 |
+
✅ Mobile-responsive design
|
| 94 |
+
✅ Error handling and validation
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## 📁 New Files Created
|
| 99 |
+
|
| 100 |
+
```
|
| 101 |
+
health-query-classifier/
|
| 102 |
+
├── 🌐 Web Interfaces
|
| 103 |
+
│ ├── app.py ⭐ (Main Gradio UI)
|
| 104 |
+
│ ├── app_streamlit.py (Alternative Streamlit UI)
|
| 105 |
+
│ ├── launch_ui.bat (Windows launcher)
|
| 106 |
+
│ └── launch_ui.ps1 (PowerShell launcher)
|
| 107 |
+
│
|
| 108 |
+
├── 📚 Documentation
|
| 109 |
+
│ ├── QUICKSTART.md (5-minute setup guide)
|
| 110 |
+
│ ├── UI_README.md (Feature documentation)
|
| 111 |
+
│ ├── UI_IMPLEMENTATION.md (Technical summary)
|
| 112 |
+
│ ├── ARCHITECTURE.md (System diagrams)
|
| 113 |
+
│ ├── PRESENTATION_SCRIPT.md (Demo script)
|
| 114 |
+
│ └── UI_SUMMARY.md (This file)
|
| 115 |
+
│
|
| 116 |
+
├── 🔧 Setup Tools
|
| 117 |
+
│ └── setup_ui.py (Automated installer)
|
| 118 |
+
│
|
| 119 |
+
└── 📦 Updated Files
|
| 120 |
+
└── requirements.txt (Added gradio + streamlit)
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
---
|
| 124 |
+
|
| 125 |
+
## 🎯 What Each Interface Looks Like
|
| 126 |
+
|
| 127 |
+
### Gradio Interface Features:
|
| 128 |
+
```
|
| 129 |
+
┌─────────────────────────────────────────┐
|
| 130 |
+
│ 🏥 Medical Q&A Bot │
|
| 131 |
+
│ Health Query Classifier & Retriever │
|
| 132 |
+
│ Team: David • Tarak • Sravani • etc. │
|
| 133 |
+
├─────────────────────────────────────────┤
|
| 134 |
+
│ │
|
| 135 |
+
│ [Enter your health query...] │
|
| 136 |
+
│ ┌─────────────────────────────────┐ │
|
| 137 |
+
│ │ Number of Results: [10] │ │
|
| 138 |
+
│ │ ☐ Use Reranker │ │
|
| 139 |
+
│ └─────────────────────────────────┘ │
|
| 140 |
+
│ │
|
| 141 |
+
│ [🔍 Analyze Query] │
|
| 142 |
+
│ │
|
| 143 |
+
├─────────────────────────────────────────┤
|
| 144 |
+
│ Classification Result: │
|
| 145 |
+
│ ✓ MEDICAL (95% confidence) │
|
| 146 |
+
│ - Medical: 95.2% ████████████ │
|
| 147 |
+
│ - Administrative: 4.8% █ │
|
| 148 |
+
├─────────────────────────────────────────┤
|
| 149 |
+
│ [📄 Formatted View] [📊 JSON View] │
|
| 150 |
+
│ │
|
| 151 |
+
│ Found 10 Relevant Documents │
|
| 152 |
+
│ ┌───────────────────────────────┐ │
|
| 153 |
+
│ │ Result #1: Eczema Treatment │ │
|
| 154 |
+
│ │ BM25: 0.85 Dense: 0.92 RRF: 1.2│ │
|
| 155 |
+
│ │ Text: Treatment options for...│ │
|
| 156 |
+
│ └───────────────────────────────┘ │
|
| 157 |
+
│ ... │
|
| 158 |
+
└─────────────────────────────────────────┘
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### Streamlit Interface Features:
|
| 162 |
+
```
|
| 163 |
+
┌─────────────┬───────────────────────────┐
|
| 164 |
+
│ ⚙️ Settings │ 🏥 Medical Q&A Bot │
|
| 165 |
+
│ │ ═══════════════════════ │
|
| 166 |
+
│ Results: 10 │ │
|
| 167 |
+
│ ▁▁▁▁▁▁▁▁ │ [Query input box...] │
|
| 168 |
+
│ │ │
|
| 169 |
+
│ ☐ Reranker │ [🔍 Analyze Query] │
|
| 170 |
+
│ │ │
|
| 171 |
+
│ ☐ JSON View │ Classification: │
|
| 172 |
+
│ │ 🏥 MEDICAL │
|
| 173 |
+
│ Examples: │ │
|
| 174 |
+
│ • Rash... │ Confidence: │
|
| 175 |
+
│ • Vaccine...│ ████████████ 95% │
|
| 176 |
+
│ • Headache..│ │
|
| 177 |
+
│ │ Results: │
|
| 178 |
+
│ │ ┌─────────────────┐ │
|
| 179 |
+
│ │ │ Result #1 │ │
|
| 180 |
+
│ │ │ BM25 │ Dense │ │
|
| 181 |
+
│ │ │ 0.85 │ 0.92 │ │
|
| 182 |
+
│ │ └─────────────────┘ │
|
| 183 |
+
└─────────────┴───────────────────────────┘
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
---
|
| 187 |
+
|
| 188 |
+
## 💡 Demo Workflow
|
| 189 |
+
|
| 190 |
+
### Perfect Demo Sequence:
|
| 191 |
+
1. **Start** → Launch UI (`python app.py`)
|
| 192 |
+
2. **Medical Query** → "I have a rash on my hands..."
|
| 193 |
+
- Show classification
|
| 194 |
+
- Show retrieved documents
|
| 195 |
+
- Point out scores
|
| 196 |
+
3. **Admin Query** → "Can I get an appointment?"
|
| 197 |
+
- Show different classification
|
| 198 |
+
- No retrieval happens
|
| 199 |
+
4. **Settings** → Adjust results, toggle reranker
|
| 200 |
+
5. **Views** → Switch between formatted and JSON
|
| 201 |
+
|
| 202 |
+
---
|
| 203 |
+
|
| 204 |
+
## 🎓 For Your Presentation
|
| 205 |
+
|
| 206 |
+
### Talking Points:
|
| 207 |
+
✅ "We built a professional web interface using Gradio"
|
| 208 |
+
✅ "The system classifies queries in real-time"
|
| 209 |
+
✅ "For medical queries, it retrieves relevant research"
|
| 210 |
+
✅ "Uses hybrid search with BM25 and dense embeddings"
|
| 211 |
+
✅ "Optional reranking for improved accuracy"
|
| 212 |
+
✅ "Clean, intuitive user experience"
|
| 213 |
+
|
| 214 |
+
### What to Demo:
|
| 215 |
+
✅ Classification confidence scores
|
| 216 |
+
✅ Document retrieval results
|
| 217 |
+
✅ Different query types (medical vs admin)
|
| 218 |
+
✅ Settings adjustment (reranker, result count)
|
| 219 |
+
✅ Multiple view modes (formatted + JSON)
|
| 220 |
+
|
| 221 |
+
### Impressive Technical Details:
|
| 222 |
+
✅ Sentence transformer embeddings
|
| 223 |
+
✅ Neural network classifier
|
| 224 |
+
✅ FAISS vector search
|
| 225 |
+
✅ RRF fusion algorithm
|
| 226 |
+
✅ Cross-encoder reranking
|
| 227 |
+
✅ Professional UI framework
|
| 228 |
+
|
| 229 |
+
---
|
| 230 |
+
|
| 231 |
+
## 🛠️ Troubleshooting
|
| 232 |
+
|
| 233 |
+
### Common Issues:
|
| 234 |
+
|
| 235 |
+
**"Module not found: gradio"**
|
| 236 |
+
```powershell
|
| 237 |
+
pip install gradio
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
**"No corpora files found"**
|
| 241 |
+
```powershell
|
| 242 |
+
python -m adapters.build_corpora
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
**"Port already in use"**
|
| 246 |
+
```python
|
| 247 |
+
# Edit app.py line ~255
|
| 248 |
+
demo.launch(server_port=8080) # Change port
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
**"Models loading slowly"**
|
| 252 |
+
- This is normal on first run
|
| 253 |
+
- Models are cached afterward
|
| 254 |
+
- Takes 30-60 seconds initially
|
| 255 |
+
|
| 256 |
+
---
|
| 257 |
+
|
| 258 |
+
## 🌟 Why This Implementation is Great
|
| 259 |
+
|
| 260 |
+
### For Your Project:
|
| 261 |
+
✅ Professional appearance
|
| 262 |
+
✅ Easy to demonstrate
|
| 263 |
+
✅ Well-documented
|
| 264 |
+
✅ Production-ready foundation
|
| 265 |
+
✅ Impressive to stakeholders
|
| 266 |
+
|
| 267 |
+
### For Your Resume:
|
| 268 |
+
✅ Modern tech stack (Gradio, PyTorch, Transformers)
|
| 269 |
+
✅ Full-stack development (UI + ML backend)
|
| 270 |
+
✅ Healthcare application (impactful domain)
|
| 271 |
+
✅ Clean, maintainable code
|
| 272 |
+
✅ Comprehensive documentation
|
| 273 |
+
|
| 274 |
+
### For Future Development:
|
| 275 |
+
✅ Easy to extend
|
| 276 |
+
✅ Modular architecture
|
| 277 |
+
✅ Multiple deployment options
|
| 278 |
+
✅ API already available
|
| 279 |
+
✅ Scalable design
|
| 280 |
+
|
| 281 |
+
---
|
| 282 |
+
|
| 283 |
+
## 📊 File Sizes
|
| 284 |
+
|
| 285 |
+
```
|
| 286 |
+
app.py ~9 KB (Main UI)
|
| 287 |
+
app_streamlit.py ~7 KB (Alt UI)
|
| 288 |
+
QUICKSTART.md ~5 KB (Setup guide)
|
| 289 |
+
UI_README.md ~8 KB (Features)
|
| 290 |
+
UI_IMPLEMENTATION.md ~10 KB (Details)
|
| 291 |
+
ARCHITECTURE.md ~15 KB (Diagrams)
|
| 292 |
+
PRESENTATION_SCRIPT.md ~12 KB (Demo guide)
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
Total new documentation: **~66 KB of helpful guides!**
|
| 296 |
+
|
| 297 |
+
---
|
| 298 |
+
|
| 299 |
+
## 🎯 Next Steps
|
| 300 |
+
|
| 301 |
+
### Immediate (5 minutes):
|
| 302 |
+
1. Run `pip install gradio streamlit`
|
| 303 |
+
2. Launch UI with `python app.py`
|
| 304 |
+
3. Test with example queries
|
| 305 |
+
4. Familiarize yourself with features
|
| 306 |
+
|
| 307 |
+
### Short Term (1 hour):
|
| 308 |
+
1. Read through QUICKSTART.md
|
| 309 |
+
2. Test both Gradio and Streamlit interfaces
|
| 310 |
+
3. Prepare demo queries
|
| 311 |
+
4. Practice presentation flow
|
| 312 |
+
|
| 313 |
+
### Before Presentation:
|
| 314 |
+
1. Review PRESENTATION_SCRIPT.md
|
| 315 |
+
2. Test demo multiple times
|
| 316 |
+
3. Prepare backup slides
|
| 317 |
+
4. Assign team roles
|
| 318 |
+
5. Get excited! 🚀
|
| 319 |
+
|
| 320 |
+
---
|
| 321 |
+
|
| 322 |
+
## 🤝 Team Credits
|
| 323 |
+
|
| 324 |
+
**Built by:**
|
| 325 |
+
- David Gray
|
| 326 |
+
- Tarak Jha
|
| 327 |
+
- Sravani Segireddy
|
| 328 |
+
- Riley Millikan
|
| 329 |
+
- Kent R. Spillner
|
| 330 |
+
|
| 331 |
+
**Technologies Used:**
|
| 332 |
+
- Python 3.8+
|
| 333 |
+
- PyTorch
|
| 334 |
+
- Sentence-Transformers
|
| 335 |
+
- Gradio
|
| 336 |
+
- Streamlit
|
| 337 |
+
- FAISS
|
| 338 |
+
- BM25
|
| 339 |
+
- scikit-learn
|
| 340 |
+
|
| 341 |
+
---
|
| 342 |
+
|
| 343 |
+
## 🎉 You're All Set!
|
| 344 |
+
|
| 345 |
+
Your medical Q&A bot now has:
|
| 346 |
+
✅ Two professional web interfaces
|
| 347 |
+
✅ Complete documentation
|
| 348 |
+
✅ Easy launchers
|
| 349 |
+
✅ Presentation guide
|
| 350 |
+
✅ Demo script
|
| 351 |
+
✅ Architecture diagrams
|
| 352 |
+
|
| 353 |
+
**Everything you need for a successful demo and presentation!**
|
| 354 |
+
|
| 355 |
+
---
|
| 356 |
+
|
| 357 |
+
## 🚀 Quick Commands Reference
|
| 358 |
+
|
| 359 |
+
```powershell
|
| 360 |
+
# Install everything
|
| 361 |
+
pip install -r requirements.txt
|
| 362 |
+
|
| 363 |
+
# Build data
|
| 364 |
+
python -m adapters.build_corpora
|
| 365 |
+
|
| 366 |
+
# Launch Gradio UI (recommended)
|
| 367 |
+
python app.py
|
| 368 |
+
|
| 369 |
+
# Launch Streamlit UI (alternative)
|
| 370 |
+
streamlit run app_streamlit.py
|
| 371 |
+
|
| 372 |
+
# Run automated setup
|
| 373 |
+
python setup_ui.py
|
| 374 |
+
|
| 375 |
+
# Use launcher scripts
|
| 376 |
+
.\launch_ui.ps1
|
| 377 |
+
```
|
| 378 |
+
|
| 379 |
+
---
|
| 380 |
+
|
| 381 |
+
## 📞 Need Help?
|
| 382 |
+
|
| 383 |
+
1. Check QUICKSTART.md for setup issues
|
| 384 |
+
2. Check UI_README.md for feature questions
|
| 385 |
+
3. Check ARCHITECTURE.md for technical details
|
| 386 |
+
4. Check PRESENTATION_SCRIPT.md for demo help
|
| 387 |
+
5. Ask your team members!
|
| 388 |
+
|
| 389 |
+
---
|
| 390 |
+
|
| 391 |
+
## ✨ Final Notes
|
| 392 |
+
|
| 393 |
+
This implementation provides:
|
| 394 |
+
- **Professional quality** - Ready to show to professors, potential employers
|
| 395 |
+
- **Well-documented** - Easy for team members to understand
|
| 396 |
+
- **Extensible** - Can be built upon for future projects
|
| 397 |
+
- **Portfolio-worthy** - Great addition to your GitHub
|
| 398 |
+
|
| 399 |
+
**You've got an impressive project here. Go show it off! 🌟**
|
| 400 |
+
|
| 401 |
+
---
|
| 402 |
+
|
| 403 |
+
*Created: December 3, 2025*
|
| 404 |
+
*For: Health Query Classifier Group Project*
|
| 405 |
+
*By: Your friendly AI assistant*
|
__init__.py
ADDED
|
File without changes
|
adapters/build_corpora.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json, jsonlines, pathlib
|
| 2 |
+
import concurrent.futures
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
from math import ceil
|
| 6 |
+
from pubmed import download_pubmed
|
| 7 |
+
|
| 8 |
+
# Output directory for every generated corpus file; created eagerly so the
# builder functions can write without checking.
OUT = pathlib.Path("data/corpora")
OUT.mkdir(parents=True, exist_ok=True)

# Assumed article count per PubMed baseline XML archive — used only to turn a
# record budget into a file count in build_pubmed (TODO confirm against the
# current baseline; files vary in size).
PUBMED_ARTICLES_PER_XML_FILE = 30000
|
| 12 |
+
|
| 13 |
+
def write_jsonl(path, rows):
    """Serialize *rows* (an iterable of dicts) to *path* in JSON Lines format."""
    print(f"Writing {len(rows)} records to {path}")
    with jsonlines.open(path, "w") as sink:
        sink.write_all(rows)
    print(f"Finished writing {path}")
|
| 18 |
+
|
| 19 |
+
# 1) LasseRegin medical Q&A
|
| 20 |
+
# 1) LasseRegin medical Q&A
def build_lasseregin():
    """Download the LasseRegin icliniq Q&A dump and write it to medical_qa.jsonl."""
    print("Starting LasseRegin build...")
    import urllib.request
    url = "https://raw.githubusercontent.com/LasseRegin/medical-question-answer-data/master/icliniqQAs.json"

    try:
        with urllib.request.urlopen(url) as response:
            data = json.loads(response.read().decode("utf-8"))
    except Exception as e:
        # Best-effort: a failed download skips this corpus instead of aborting.
        print(f"Failed to download LasseRegin data: {e}")
        return

    rows = [
        {
            "id": f"icliniq:{i}",
            "title": record.get("title", ""),
            "question": record.get("question", ""),
            "answer": record.get("answer", ""),
            "source": "icliniq",
        }
        for i, record in enumerate(tqdm(data, desc="LasseRegin", leave=False))
    ]
    write_jsonl(OUT / "medical_qa.jsonl", rows)
    print("Completed LasseRegin build.")
|
| 43 |
+
|
| 44 |
+
# 2) MIRIAD-4.4M-split
|
| 45 |
+
# 2) MIRIAD-4.4M-split
def build_miriad(sample_size=200_000):
    """Sample up to *sample_size* QA passages from MIRIAD 4.4M into miriad_text.jsonl."""
    print(f"Starting MIRIAD build (sample_size={sample_size})...")
    try:
        ds = load_dataset("miriad/miriad-4.4M", num_proc=4, split="train")
        # Fixed seed keeps the subsample reproducible across builds.
        ds = ds.shuffle(seed=42).select(range(min(sample_size, len(ds))))
    except Exception as e:
        print(f"Failed to load MIRIAD dataset: {e}")
        return

    rows = [
        {
            "id": f"miriad:{idx}",
            "title": example.get("paper_title", ""),
            "question": example.get("question", ""),
            "answer": example.get("passage_text", ""),
            "year": example.get("year", ""),
            "specialty": example.get("specialty", ""),
        }
        for idx, example in enumerate(tqdm(ds, desc="miriad", leave=False))
    ]
    write_jsonl(OUT / "miriad_text.jsonl", rows)
    print("Completed MIRIAD build.")
|
| 68 |
+
|
| 69 |
+
# 3) PubMed abstracts
|
| 70 |
+
# 3) PubMed abstracts
def build_pubmed(max_records=500_000):
    """Download enough PubMed baseline files to cover roughly *max_records* abstracts."""
    # Convert the record budget into a whole number of baseline archives.
    file_count = int(ceil(max_records / PUBMED_ARTICLES_PER_XML_FILE))
    print(f"Starting PubMed build (num_files={file_count}, max_records={max_records})...")

    download_pubmed(OUT / "pubmed.jsonl", file_count)
    print("Completed PubMed build.")
|
| 76 |
+
|
| 77 |
+
# 4) UniDoc-Bench (QA)
|
| 78 |
+
# 4) UniDoc-Bench (QA)
def build_unidoc(max_items=1000):
    """Extract up to *max_items* healthcare QA pairs from UniDoc-Bench into unidoc_qa.jsonl."""
    print(f"Starting UniDoc build (max_items={max_items})...")
    try:
        ds = load_dataset("Salesforce/UniDoc-Bench", split="healthcare")
    except Exception as e:
        print(f"Failed to load UniDoc dataset: {e}")
        return

    rows = []
    for idx, example in enumerate(tqdm(ds, desc="unidoc", leave=False)):
        # Field names vary between dataset revisions; try both spellings.
        rows.append({
            "id": f"unidoc:{idx}",
            "title": f"{example.get('domain', '')} PDF",
            "question": example.get("question", "") or example.get("query", ""),
            "answer": example.get("answer", "") or "",
            "pdf_path": example.get("pdf_path") or example.get("document_path") or "",
        })
        if idx + 1 >= max_items:
            break
    write_jsonl(OUT / "unidoc_qa.jsonl", rows)
    print("Completed UniDoc build.")
|
| 103 |
+
|
| 104 |
+
def main():
    """Build all four corpora concurrently (the work is network-bound, so threads suffice)."""
    print("Starting parallel corpora build...")
    # (builder, positional-args) pairs; the sizes are deliberately small samples.
    tasks = [
        (build_lasseregin, []),
        (build_miriad, [1000]),
        (build_pubmed, [500_000]),
        (build_unidoc, [1000]),
    ]

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        pending = [pool.submit(fn, *fn_args) for fn, fn_args in tasks]
        for done in concurrent.futures.as_completed(pending):
            try:
                done.result()
            except Exception as e:
                # One failing corpus should not stop the others.
                print(f"A task failed: {e}")

    print("✅ All corpora built successfully in data/corpora/")

if __name__ == "__main__":
    main()
|
adapters/pubmed.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import gzip
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
import subprocess
|
| 8 |
+
import xml.etree.ElementTree as ET
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from urllib.request import urlopen, urlretrieve
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Root of NCBI's annual PubMed baseline mirror (directory listing is plain HTML).
PUBMED_DATASET_BASE_URL = "https://ftp.ncbi.nlm.nih.gov/pubmed/baseline"

# Number of baseline archives fetched when this module is run as a script.
PUBMED_FILE_LIMIT = 10
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_pubmed_dataset_size():
    """Count distinct baseline archives listed on the NCBI mirror (0 on any failure)."""
    try:
        with urlopen(PUBMED_DATASET_BASE_URL) as response:
            listing = response.read().decode("utf-8")

        # Match archive names, excluding their ".xml.gz.md5" companions.
        names = re.findall(r"(pubmed\d+n\d+)\.xml\.gz(?!\.)", listing)
        return len(set(names))

    except Exception as e:
        print(f"Unable to count PubMed files: {e}")

        return 0
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def download_pubmed_xml(output_dir, num_files=1, year='25'):
    """Download the first *num_files* PubMed baseline archives into *output_dir*.

    Already-present files are kept (resumable). Args:
        output_dir: directory for the ``.xml.gz`` files (created if missing).
        num_files: how many sequentially numbered archives to fetch.
        year: two-digit baseline year embedded in the file names.

    Returns:
        list[str]: local paths of the requested archives, in order.
    """
    os.makedirs(output_dir, exist_ok=True)

    total_dataset_size = get_pubmed_dataset_size()

    files = []
    # BUG FIX: total was the whole-dataset file count, so the bar could never
    # complete; we only ever process num_files entries. The description keeps
    # the dataset size for context.
    pbar = tqdm(total=num_files, desc=f"Downloading {num_files}/{total_dataset_size} files in PubMed dataset")

    for i in range(1, num_files + 1):
        filename = f"pubmed{year}n{i:04d}.xml.gz"
        filepath = os.path.join(output_dir, filename)

        if not os.path.exists(filepath):
            # BUG FIX: the URL must reference the actual archive name; the
            # previous placeholder produced 404s for every file.
            urlretrieve(f"{PUBMED_DATASET_BASE_URL}/{filename}", filepath)

        pbar.update(1)

        files.append(filepath)

    pbar.close()

    return files
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def parse_pubmed_to_jsonl(xml_files, output_jsonl):
    """Stream-parse gzipped PubMed XML archives into a JSONL corpus.

    MEMORY FIX: ``ET.parse`` materialized each multi-hundred-MB archive as a
    full in-memory tree. ``iterparse`` yields articles incrementally and each
    article element is cleared once written, keeping memory roughly flat.

    Each output line carries ``id`` (PMID), ``title`` and ``contents``
    (title + abstract), matching the previous format exactly.
    """
    with open(output_jsonl, 'w') as out:
        for xml_file in xml_files:
            print(f"Parsing {xml_file}...")
            with gzip.open(xml_file, 'rt', encoding='utf-8') as f:
                # 'end' events fire when an element's closing tag is read, so a
                # PubmedArticle is complete (children included) when we see it.
                for _, elem in tqdm(ET.iterparse(f, events=('end',))):
                    if elem.tag != 'PubmedArticle':
                        continue

                    pmid_elem = elem.find('.//PMID')
                    title_elem = elem.find('.//ArticleTitle')
                    abstract_elem = elem.find('.//Abstract/AbstractText')

                    if pmid_elem is not None:
                        title = title_elem.text if title_elem is not None else ""
                        abstract = abstract_elem.text if abstract_elem is not None else ""

                        doc = {
                            'id': pmid_elem.text,
                            'title': title,
                            'contents': f"{title} {abstract}".strip()
                        }
                        out.write(json.dumps(doc) + '\n')

                    # Release the article subtree we just consumed.
                    elem.clear()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def download_pubmed(output_jsonl, num_files=1):
    """Fetch and parse PubMed baseline files unless *output_jsonl* already exists."""
    if os.path.exists(output_jsonl):
        print(f"Already downloaded PubMed dataset: {output_jsonl}")

        return

    # Keep the raw XML in a sibling directory one level above the JSONL output.
    archive_dir = os.path.join(os.path.dirname(output_jsonl), '../pubmed-xml')
    archives = download_pubmed_xml(archive_dir, num_files=num_files)
    parse_pubmed_to_jsonl(archives, output_jsonl)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def build_index_cmd(input_file, index_dir):
    """Assemble the pyserini Lucene-indexing command line for *input_file*."""
    # Pyserini indexes a directory of JSON files, not a single file.
    corpus_dir = os.path.dirname(input_file)
    return [
        "python", "-m", "pyserini.index.lucene",
        "--collection", "JsonCollection",
        "--input", corpus_dir,
        "--index", index_dir,
        "--generator", "DefaultLuceneDocumentGenerator",
        "--threads", "32",
        "--storePositions",
        "--storeDocvectors",
        "--storeRaw",
    ]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def build_index(input_file, index_dir, cmd_generator=build_index_cmd):
    """Build a Lucene index for *input_file*, skipping if *index_dir* is already populated."""
    if os.path.exists(index_dir) and os.listdir(index_dir):
        print(f"Skipping existing index: {index_dir}")

        return

    os.makedirs(os.path.dirname(index_dir) or '.', exist_ok=True)

    # Fail loudly (check=True) if the indexing subprocess errors out.
    subprocess.run(cmd_generator(input_file, index_dir), check=True)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def main(base_data_dir="data", base_index_dir="indexes", num_files=1):
    """Download the PubMed corpus and build its Lucene index."""
    corpus_path = os.path.join(base_data_dir, "pubmed", "corpus.jsonl")
    lucene_dir = os.path.join(base_index_dir, "pubmed")

    download_pubmed(corpus_path, num_files=num_files)

    build_index(corpus_path, lucene_dir)


if __name__ == "__main__":
    main(num_files=PUBMED_FILE_LIMIT)
|
app_retrieval_cached.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Medical Q&A UI - BM25 + Dense Retrieval Models WITH DISK CACHING
|
| 3 |
+
This version caches the indexes to disk for fast startup (30 seconds vs 5-8 minutes!)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from typing import Dict, List
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import pickle
|
| 10 |
+
import hashlib
|
| 11 |
+
import json
|
| 12 |
+
from retriever.index_bm25 import BM25Index
|
| 13 |
+
from retriever.index_dense import DenseIndex
|
| 14 |
+
from retriever.ingest import load_jsonl
|
| 15 |
+
from retriever.rrf import rrf
|
| 16 |
+
from team.interfaces import Candidate
|
| 17 |
+
|
| 18 |
+
# Cache directory
# Pickled artifacts (document list + both indexes) live here, keyed by an MD5
# of the corpora config; created eagerly so first-run writes always succeed.
CACHE_DIR = Path("cache")
CACHE_DIR.mkdir(exist_ok=True)

# Startup banner — note this module builds/loads indexes at import time.
print("=" * 70)
print(" Medical Document Retrieval System (CACHED VERSION)")
print(" Using BM25 + Dense Embeddings + RRF Fusion")
print(" With disk caching for fast startup!")
print("=" * 70)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _default_corpora_config() -> Dict[str, dict]:
    """Default corpus registry: name -> {path, text_fields}."""
    fields = ["question", "answer", "title"]
    return {
        "medical_qa": {"path": "data/corpora/medical_qa.jsonl",
                       "text_fields": list(fields)},
        "miriad": {"path": "data/corpora/miriad_text.jsonl",
                   "text_fields": list(fields)},
        "unidoc": {"path": "data/corpora/unidoc_qa.jsonl",
                   "text_fields": list(fields)},
    }
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _available(cfg: Dict[str, dict]) -> Dict[str, dict]:
    """Keep only corpora whose JSONL file actually exists on disk."""
    present: Dict[str, dict] = {}
    for name, spec in cfg.items():
        if Path(spec["path"]).exists():
            present[name] = spec
    return present
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _get_cache_key(corpora_config: Dict[str, dict]) -> str:
    """Derive a stable MD5 fingerprint of the corpora configuration."""
    # sort_keys makes the serialization order-independent, so the same config
    # always maps to the same cache files.
    canonical = json.dumps(corpora_config, sort_keys=True).encode()
    return hashlib.md5(canonical).hexdigest()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class CachedRetriever:
    """Retriever with disk caching for BM25 and Dense indexes.

    The three expensive artifacts (merged document list, BM25 index, dense
    index) are pickled under ``cache/`` keyed by an MD5 of the corpora
    config, so later launches skip the multi-minute build.

    NOTE(review): pickle.load is only safe because these caches are produced
    locally — never point ``cache/`` at untrusted files.
    """

    def __init__(self, corpora_config: Dict[str, dict], use_reranker: bool = False):
        self.corpora_config = corpora_config
        self.use_reranker = use_reranker
        self.cache_key = _get_cache_key(corpora_config)

        # Cache file paths
        self.bm25_cache = CACHE_DIR / f"bm25_{self.cache_key}.pkl"
        self.dense_cache = CACHE_DIR / f"dense_{self.cache_key}.pkl"
        self.docs_cache = CACHE_DIR / f"docs_{self.cache_key}.pkl"

        # Load or build indexes (documents first — both indexes derive from them)
        self.docs_all = self._load_or_build_docs()
        self.bm25 = self._load_or_build_bm25()
        self.dense = self._load_or_build_dense()

    def _load_or_build_docs(self) -> List:
        """Load documents from cache or build from scratch"""
        if self.docs_cache.exists():
            print(f"Loading documents from cache... ({self.docs_cache.name})")
            try:
                with open(self.docs_cache, 'rb') as f:
                    docs_all = pickle.load(f)
                print(f" ✓ Loaded {len(docs_all)} documents from cache")
                return docs_all
            except Exception as e:
                # Corrupt/stale cache: fall through to a full rebuild.
                print(f" ✗ Cache load failed: {e}")
                print(" → Rebuilding documents...")

        print("Building documents from corpora files...")
        docs_all = []
        for name, cfg in self.corpora_config.items():
            print(f" Loading {name}...")
            docs = load_jsonl(cfg["path"], tuple(cfg.get("text_fields", ("question", "answer"))))
            docs_all.extend(docs)

        # Save to cache
        print(f"Saving documents to cache... ({len(docs_all)} docs)")
        with open(self.docs_cache, 'wb') as f:
            pickle.dump(docs_all, f)

        return docs_all

    def _load_or_build_bm25(self) -> BM25Index:
        """Load BM25 index from cache or build from scratch"""
        if self.bm25_cache.exists():
            print(f"Loading BM25 index from cache... ({self.bm25_cache.name})")
            try:
                with open(self.bm25_cache, 'rb') as f:
                    bm25_index = pickle.load(f)
                print(f" ✓ BM25 index loaded from cache")
                return bm25_index
            except Exception as e:
                print(f" ✗ Cache load failed: {e}")
                print(" → Rebuilding BM25 index...")

        print("Building BM25 index from scratch...")
        bm25_index = BM25Index(self.docs_all)

        # Save to cache
        print(f"Saving BM25 index to cache...")
        with open(self.bm25_cache, 'wb') as f:
            pickle.dump(bm25_index, f)

        return bm25_index

    def _load_or_build_dense(self) -> DenseIndex:
        """Load Dense index from cache or build from scratch"""
        if self.dense_cache.exists():
            print(f"Loading Dense index from cache... ({self.dense_cache.name})")
            try:
                with open(self.dense_cache, 'rb') as f:
                    dense_index = pickle.load(f)
                print(f" ✓ Dense index loaded from cache")
                return dense_index
            except Exception as e:
                print(f" ✗ Cache load failed: {e}")
                print(" → Rebuilding Dense index...")

        # Embedding every document is the slow path — this is why the cache exists.
        print("Building Dense index from scratch (this takes 5-8 minutes)...")
        dense_index = DenseIndex(self.docs_all)

        # Save to cache
        print(f"Saving Dense index to cache...")
        with open(self.dense_cache, 'wb') as f:
            pickle.dump(dense_index, f)

        return dense_index
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Initialize cached retriever (fast if cached, slow first time)
print("\nInitializing retrieval system...")
# Index only the corpora whose JSONL files are present on disk.
cfg = _available(_default_corpora_config())
if not cfg:
    raise RuntimeError("No corpora files found in data/corpora. Build them first.")

retriever = CachedRetriever(corpora_config=cfg, use_reranker=False)

print("\n✓ Retrieval system ready!")
print(f" Total documents indexed: {len(retriever.docs_all):,}")
print("=" * 70)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_candidates_cached(query: str, k_retrieve: int = 50) -> List[Candidate]:
    """Return the top fused candidates for *query*, sorted by RRF score,
    with component scores (bm25, dense, rrf) attached to each candidate.
    Uses the module-level cached retriever for fast queries.
    """
    # Retrieve generously from both indexes so fusion has material to work with.
    bm25_hits = retriever.bm25.search(query, k=max(k_retrieve, 100))
    dense_hits = retriever.dense.search(query, k=max(k_retrieve, 100))

    # id -> raw component score, for attaching to candidates below.
    bm25_scores = {doc.id: float(score) for doc, score in bm25_hits}
    dense_scores = {doc.id: float(score) for doc, score in dense_hits}

    # id -> 0-based rank position within each result list.
    bm25_ranks = {doc.id: pos for pos, (doc, _) in enumerate(bm25_hits)}
    dense_ranks = {doc.id: pos for pos, (doc, _) in enumerate(dense_hits)}

    # Fused candidate pool.
    fused = rrf([bm25_hits, dense_hits], k=max(k_retrieve, 50))

    K = 60  # conventional RRF damping constant

    candidates: List[Candidate] = []
    for doc, _ in fused[:k_retrieve]:
        # Recompute the RRF contribution from each list the doc appears in.
        fusion_score = 0.0
        for ranks in (bm25_ranks, dense_ranks):
            if doc.id in ranks:
                fusion_score += 1.0 / (K + ranks[doc.id] + 1)
        candidates.append(Candidate(
            id=doc.id,
            title=doc.title or "",
            text=doc.text,
            meta=doc.meta or {},
            bm25=bm25_scores.get(doc.id, 0.0),
            dense=dense_scores.get(doc.id, 0.0),
            rrf=fusion_score,
        ))

    # Baseline presentation order: best fused score first.
    candidates.sort(key=lambda c: c.rrf, reverse=True)
    return candidates
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def retrieve_documents(query, num_results=5):
    """Retrieve relevant medical documents and render them as an HTML report.

    Args:
        query: free-text medical query from the UI.
        num_results: number of fused candidates to display.

    Returns:
        str: HTML for the Gradio HTML component (info card for empty input,
        warning card for no hits, error card on any exception).

    NOTE(review): document text is interpolated into the HTML unescaped —
    fine for trusted corpora, but confirm before exposing to arbitrary data.
    """
    if not query or not query.strip():
        return """
        <div style="padding: 20px; background-color: #e7f3ff; border-radius: 10px; border-left: 5px solid #2196f3;">
            <h3 style="margin-top: 0; color: #0d47a1;">How to Use</h3>
            <p style="margin: 0; color: #1565c0;">Enter a medical query and we'll find relevant documents using BM25 + Dense retrieval with RRF fusion.</p>
            <p style="margin: 8px 0 0 0; color: #1565c0;"><strong>Example:</strong> "headache with blurred vision" or "symptoms of diabetes"</p>
        </div>
        """

    try:
        # Use cached retrieval system (fast!)
        hits = get_candidates_cached(query=query, k_retrieve=num_results)

        if not hits:
            return """
            <div style="padding: 20px; background-color: #fff3cd; border-radius: 10px; border-left: 5px solid #ffc107;">
                <h3 style="margin-top: 0; color: #856404;">No Results Found</h3>
                <p style="margin: 0; color: #856404;">Try rephrasing your query or using different medical terms.</p>
            </div>
            """

        # Build results HTML
        result_html = f"""
        <div style="padding: 15px; background: linear-gradient(135deg, #d4edda 0%, #c3e6cb 100%); border-radius: 10px; margin-bottom: 20px; border-left: 5px solid #28a745;">
            <h3 style="margin-top: 0; color: #155724;">Found {len(hits)} Relevant Medical Documents</h3>
            <p style="margin: 0;"><strong>Retrieved using:</strong> BM25 + Dense Embeddings + RRF Fusion (CACHED)</p>
        </div>
        """

        for i, hit in enumerate(hits, 1):
            # Treat whitespace-only titles as missing.
            title = hit.title if hit.title and hit.title.strip() else None
            source = hit.meta.get('source', 'Unknown') if hit.meta else 'Unknown'

            # Check if we have separate question/answer fields in metadata
            question = hit.meta.get('question', '') if hit.meta else ''
            answer = hit.meta.get('answer', '') if hit.meta else ''

            # If we have separate Q&A, format them nicely
            if question and answer:
                content_html = f"""
                <div style="margin-bottom: 12px;">
                    <strong style="color: #1976d2;">Question:</strong>
                    <p style="margin: 5px 0 0 0; line-height: 1.6; color: #424242;">{question}</p>
                </div>
                <div>
                    <strong style="color: #388e3c;">Answer:</strong>
                    <p style="margin: 5px 0 0 0; line-height: 1.6; color: #424242;">{answer[:500] + ("..." if len(answer) > 500 else "")}</p>
                </div>
                """
            else:
                # Fallback to combined text
                text = hit.text[:500] + ("..." if len(hit.text) > 500 else "")
                content_html = f'<p style="margin: 0; line-height: 1.7; color: #34495e;">{text}</p>'

            # Display relevance scores
            bm25_score = hit.bm25
            dense_score = hit.dense
            rrf_score = hit.rrf

            # Build title HTML only if title exists
            title_html = f'<h4 style="margin: 0 0 15px 0; color: #2c3e50;">{title}</h4>' if title else ''

            result_html += f"""
            <div style="border: 2px solid #dee2e6; padding: 20px; margin: 20px 0; border-radius: 10px; background-color: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
                <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; margin: -20px -20px 20px -20px; border-radius: 8px 8px 0 0;">
                    <div style="display: flex; justify-content: space-between; align-items: center;">
                        <h4 style="margin: 0; color: white;">Document #{i}</h4>
                        <span style="background-color: rgba(255,255,255,0.2); padding: 4px 12px; border-radius: 12px; font-size: 0.85em; color: white;">
                            {source}
                        </span>
                    </div>
                </div>

                <div style="margin-bottom: 15px;">
                    {title_html}
                    {content_html}
                </div>

                <div style="padding-top: 12px; border-top: 1px solid #e9ecef;">
                    <div style="display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 10px;">
                        <div style="background-color: #e3f2fd; padding: 8px; border-radius: 5px; text-align: center;">
                            <div style="font-size: 0.75em; color: #1976d2; font-weight: bold;">BM25</div>
                            <div style="font-size: 1.1em; color: #0d47a1;">{bm25_score:.4f}</div>
                        </div>
                        <div style="background-color: #f3e5f5; padding: 8px; border-radius: 5px; text-align: center;">
                            <div style="font-size: 0.75em; color: #7b1fa2; font-weight: bold;">Dense</div>
                            <div style="font-size: 1.1em; color: #4a148c;">{dense_score:.4f}</div>
                        </div>
                        <div style="background-color: #e8f5e9; padding: 8px; border-radius: 5px; text-align: center;">
                            <div style="font-size: 0.75em; color: #388e3c; font-weight: bold;">RRF Fusion</div>
                            <div style="font-size: 1.1em; color: #1b5e20;">{rrf_score:.4f}</div>
                        </div>
                    </div>
                </div>
            </div>
            """

        return result_html

    except Exception as e:
        # Broad catch is deliberate here: any failure becomes a visible error card.
        return f"""
        <div style="padding: 20px; background-color: #f8d7da; border-radius: 10px; border-left: 5px solid #dc3545;">
            <h3 style="margin-top: 0; color: #721c24;">Error</h3>
            <p style="margin: 0; color: #721c24;">{str(e)}</p>
        </div>
        """
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# Create Gradio interface
with gr.Blocks(title="Medical Document Retrieval (Cached)") as demo:
    gr.Markdown("""
    # Medical Document Retrieval System (CACHED VERSION)

    **Models:**
    - BM25 Index (keyword-based retrieval)
    - Dense Embeddings (embeddinggemma-300m-medical)
    - RRF Fusion (combines both approaches)

    ### Features:
    - Searches across 10,000+ medical documents
    - Shows relevance scores from each model component
    - Returns the most relevant medical information
    """)

    with gr.Row():
        with gr.Column():
            # Free-text query box.
            query_input = gr.Textbox(
                label="Enter your medical query",
                placeholder="Example: headache with blurred vision",
                lines=2
            )
            # How many fused candidates retrieve_documents will render.
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of results to retrieve"
            )
            submit_btn = gr.Button("Retrieve Documents", variant="primary", size="lg")

    output_html = gr.HTML(label="Search Results")

    # Wire the button to the retrieval/rendering function.
    submit_btn.click(
        fn=retrieve_documents,
        inputs=[query_input, num_results],
        outputs=output_html
    )

    gr.Examples(
        examples=[
            "headache with blurred vision",
            "symptoms of diabetes",
            "chest pain when exercising",
            "treatment for high blood pressure",
            "causes of chronic fatigue",
        ],
        inputs=query_input,
        label="Try these example queries:"
    )

    gr.Markdown("""
    ---
    ### Technical Details
    - **BM25**: Statistical keyword matching (TF-IDF based)
    - **Dense**: Semantic search using transformer embeddings
    - **RRF Fusion**: Reciprocal Rank Fusion combines both methods
    - **Caching**: Indexes saved to disk in `cache/` folder for fast reloading

    *Note: First launch builds and caches indexes (5-8 min). After that, startup takes only ~30 seconds!*
    """)

print("\nOpening web interface...")
print(" Local access: http://127.0.0.1:7863")
print(" Public link will be generated...")
print("=" * 70)

if __name__ == "__main__":
    # share=True publishes a temporary public Gradio link in addition to localhost.
    demo.launch(server_name="127.0.0.1", server_port=7863, share=True)
|
classifier/__init__.py
ADDED
|
File without changes
|
classifier/config.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
def load_env():
    """Load KEY=VALUE pairs from a local ``env.list`` file into ``os.environ``.

    Blank lines and ``#`` comments are ignored. ROBUSTNESS FIX: lines without
    an ``=`` are now skipped instead of raising ValueError on unpacking.
    Values may themselves contain ``=``; only the first one splits key from
    value. Does nothing when ``env.list`` is absent.
    """
    if not os.path.exists("env.list"):
        return
    with open("env.list", "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if "=" not in line:
                # Malformed entry: ignore rather than crash at import time.
                continue
            key, value = line.split("=", 1)
            os.environ[key] = value

load_env()
# Hugging Face access token, if provided via env.list or the environment.
HF_TOKEN = os.getenv("HF_TOKEN")
|
classifier/head.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch
|
| 4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 5 |
+
|
| 6 |
+
class ClassifierHead(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://huggingface.co/davidgray/health-query-triage",
    pipeline_tag="text-classification",
    library_name="PyTorch",
    tags=["medical", "classification"],
):
    """MLP classification head applied on top of sentence embeddings."""

    def __init__(self, num_classes: int, embedding_dim: int = 768):  # Embedding-Gemma-300M has a 768-dimensional output
        super().__init__()

        # Two ELU hidden layers with heavy dropout; the final layer emits raw
        # class logits (softmax is applied separately in predict_proba).
        self.linear_elu_stack = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Calculates logits from the sentence embedding.

        Args:
            features (Dict[str, torch.Tensor]): Output dictionary from the Sentence Transformer body,
                containing 'sentence_embedding'.
        Returns:
            Dict[str, torch.Tensor]: Dictionary with the 'logits' key.
        """
        embeddings = features['sentence_embedding']
        logits = self.linear_elu_stack(embeddings)
        return {"logits": logits}

    def predict(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Classifies embeddings into integer labels in the range [0, num_classes).

        Args:
            embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].

        Returns:
            torch.Tensor: Integer labels with shape [num_inputs].
        """
        # Get probabilities and find the class with the highest probability
        proba = self.predict_proba(embeddings)
        return torch.argmax(proba, dim=-1)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Classifies embeddings into probabilities for each class (summing to 1).

        Args:
            embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].

        Returns:
            torch.Tensor: Float probabilities with shape [num_inputs, num_classes].
        """
        # BUG FIX: the previous version called self.train() unconditionally on
        # exit, silently re-enabling dropout for modules that had been put in
        # eval mode by the caller. Save the current mode and restore it.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            logits = self.linear_elu_stack(embeddings)
            # Convert logits to probabilities using Softmax
            probabilities = self.softmax(logits)
        if was_training:
            self.train()

        return probabilities

    def get_loss_fn(self) -> nn.Module:
        """
        Returns an initialized loss function for training.

        Returns:
            nn.Module: An initialized loss function (e.g., CrossEntropyLoss).
        """
        # CrossEntropyLoss expects logits (raw scores) as input
        return nn.CrossEntropyLoss()
|
classifier/infer.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from classifier.head import ClassifierHead
|
| 2 |
+
from classifier.utils import CATEGORIES, CHECKPOINT_PATH, DEVICE, get_models, CLASSIFIER_NAME, get_latest_checkpoint
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import pprint
|
| 6 |
+
import torch
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
|
| 9 |
+
def classifier_init(checkpoint_path: str | None = None, model_id: str | None = CLASSIFIER_NAME) -> tuple[SentenceTransformer, ClassifierHead]:
    """Load the embedding model and classifier head for inference.

    Args:
        checkpoint_path: Directory containing local training checkpoints. When
            given, the most recent checkpoint found there is loaded and
            ``model_id`` is ignored.
        model_id: Hub model id (or path) to load when no checkpoint path is
            supplied. Defaults to ``CLASSIFIER_NAME``.

    Returns:
        tuple[SentenceTransformer, ClassifierHead]: The sentence-embedding
        body and the classification head.
    """
    if checkpoint_path:
        # Prefer the newest local checkpoint over the published model id.
        latest_checkpoint = get_latest_checkpoint(checkpoint_path)
        print(f"Loading checkpoint from {latest_checkpoint}")
        embedding_model, classifier = get_models(model_id=latest_checkpoint)
    else:
        embedding_model, classifier = get_models(model_id=model_id)

    return embedding_model, classifier
|
| 18 |
+
|
| 19 |
+
def predict_query(
    text: list[str],
    embedding_model: SentenceTransformer,
    classifier_head: ClassifierHead,
) -> dict:
    """
    Runs the full inference pipeline: Text -> Embedding -> Classification.

    Args:
        text: Batch of query strings to classify.
        embedding_model: Sentence-transformer body that produces embeddings.
        classifier_head: Classification head applied to the embeddings.

    Returns:
        dict with keys:
            'prediction': list of predicted label names, one per input.
            'confidence': list of confidences for the predicted labels, one
                per input (always a list, matching 'prediction').
            'probabilities': per-class probabilities; a flat list for a single
                input, a list of per-class lists for a batch.
    """
    # Set models to evaluation mode
    embedding_model.eval()
    classifier_head.eval()

    with torch.no_grad():
        # Embed the text
        embeddings = embedding_model.encode(
            text,
            convert_to_tensor=True,
            device=DEVICE
        ).to(DEVICE)

        # Calculate probabilities and prediction
        probabilities = classifier_head.predict_proba(embeddings)

        # FIX: previously `.squeeze().tolist()` collapsed the batch dimension
        # for a single input, yielding a bare float while 'prediction' stayed
        # a list. `squeeze(1)` keeps the batch dimension so 'confidence' is
        # consistently a list.
        predicted_indices = torch.argmax(probabilities, dim=1)
        confidences = torch.gather(
            probabilities, dim=1, index=predicted_indices.unsqueeze(1)
        ).squeeze(1).tolist()

        # Get the predicted label name (index with plain ints, not 1-element tensors)
        predicted_labels = [CATEGORIES[int(i)] for i in predicted_indices]

        return {
            'prediction': predicted_labels,
            'confidence': confidences,
            'probabilities': probabilities.cpu().squeeze().tolist()
        }
|
| 54 |
+
|
| 55 |
+
def test(local: bool = False):
    """Smoke-test the classifier on a handful of sample portal queries.

    Args:
        local: When True, load the most recent local checkpoint from
            ``CHECKPOINT_PATH`` instead of the published model.
    """
    embedding_model, classifier = classifier_init(
        checkpoint_path=CHECKPOINT_PATH if local else None
    )

    # Mix of medical and administrative/insurance-style queries (typos are
    # intentional test data — do not "fix" them).
    queries = [
        "Hi! I'm having a really bad rash on my hands. I'm pretty sure it's my excema flairing up. Is there anythign stronger than aquaphor I can use on it?",
        "Hey is there any way I can get an appointment in the next month?",
        "Hey is there any way I can get an appointment in the next month with a doctor?",
        "I'm traveling to South America soon. Do I need to get any vaccines before I go?",
        "I have this rash that popped up today.",
        "How can I make this hosptial bill go away?",
        "I'm so confused do I have to cover the full cost of this operation?",
    ]

    prediction = predict_query(
        text=queries,
        embedding_model=embedding_model,
        classifier_head=classifier,
    )

    pprint.pprint(prediction, indent=4)
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
    # CLI entry point: run the smoke test against either the published model
    # or the latest local checkpoint.
    parser = argparse.ArgumentParser(
        description="Inference on a classifier for triaging health queries"
    )
    parser.add_argument(
        "--local",
        action="store_true",
        help="Use local checkpoint",
    )
    cli_args = parser.parse_args()

    test(local=cli_args.local)
|
classifier/modelcard_template.md
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
|
| 3 |
+
# Doc / guide: https://huggingface.co/docs/hub/model-cards
|
| 4 |
+
{{ card_data }}
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
# Model Card for {{ model_id | default("Model ID", true) }}
|
| 8 |
+
|
| 9 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 10 |
+
|
| 11 |
+
{{ model_summary | default("", true) }}
|
| 12 |
+
|
| 13 |
+
## Model Details
|
| 14 |
+
|
| 15 |
+
### Model Description
|
| 16 |
+
|
| 17 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 18 |
+
|
| 19 |
+
{{ model_description | default("", true) }}
|
| 20 |
+
|
| 21 |
+
- **Developed by:** {{ developers | default("[More Information Needed]", true)}}
|
| 22 |
+
- **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}}
|
| 23 |
+
- **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}}
|
| 24 |
+
- **Model type:** {{ model_type | default("[More Information Needed]", true)}}
|
| 25 |
+
- **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}}
|
| 26 |
+
- **License:** {{ license | default("[More Information Needed]", true)}}
|
| 27 |
+
- **Finetuned from model [optional]:** {{ base_model | default("[More Information Needed]", true)}}
|
| 28 |
+
|
| 29 |
+
### Model Sources [optional]
|
| 30 |
+
|
| 31 |
+
<!-- Provide the basic links for the model. -->
|
| 32 |
+
|
| 33 |
+
- **Repository:** {{ repo | default("[More Information Needed]", true)}}
|
| 34 |
+
- **Paper [optional]:** {{ paper | default("[More Information Needed]", true)}}
|
| 35 |
+
- **Demo [optional]:** {{ demo | default("[More Information Needed]", true)}}
|
| 36 |
+
|
| 37 |
+
## Uses
|
| 38 |
+
|
| 39 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 40 |
+
|
| 41 |
+
### Direct Use
|
| 42 |
+
|
| 43 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 44 |
+
|
| 45 |
+
{{ direct_use | default("[More Information Needed]", true)}}
|
| 46 |
+
|
| 47 |
+
### Downstream Use [optional]
|
| 48 |
+
|
| 49 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 50 |
+
|
| 51 |
+
{{ downstream_use | default("[More Information Needed]", true)}}
|
| 52 |
+
|
| 53 |
+
### Out-of-Scope Use
|
| 54 |
+
|
| 55 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 56 |
+
|
| 57 |
+
{{ out_of_scope_use | default("[More Information Needed]", true)}}
|
| 58 |
+
|
| 59 |
+
## Bias, Risks, and Limitations
|
| 60 |
+
|
| 61 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 62 |
+
|
| 63 |
+
{{ bias_risks_limitations | default("[More Information Needed]", true)}}
|
| 64 |
+
|
| 65 |
+
### Recommendations
|
| 66 |
+
|
| 67 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 68 |
+
|
| 69 |
+
{{ bias_recommendations | default("Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.", true)}}
|
| 70 |
+
|
| 71 |
+
## How to Get Started with the Model
|
| 72 |
+
|
| 73 |
+
Use the code below to get started with the model.
|
| 74 |
+
|
| 75 |
+
{{ get_started_code | default("[More Information Needed]", true)}}
|
| 76 |
+
|
| 77 |
+
## Training Details
|
| 78 |
+
|
| 79 |
+
### Training Data
|
| 80 |
+
|
| 81 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 82 |
+
|
| 83 |
+
{{ training_data | default("[More Information Needed]", true)}}
|
| 84 |
+
|
| 85 |
+
### Training Procedure
|
| 86 |
+
|
| 87 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 88 |
+
|
| 89 |
+
#### Preprocessing [optional]
|
| 90 |
+
|
| 91 |
+
{{ preprocessing | default("[More Information Needed]", true)}}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
#### Training Hyperparameters
|
| 95 |
+
|
| 96 |
+
- **Training regime:** {{ training_regime | default("[More Information Needed]", true)}} <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 97 |
+
|
| 98 |
+
#### Speeds, Sizes, Times [optional]
|
| 99 |
+
|
| 100 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 101 |
+
|
| 102 |
+
{{ speeds_sizes_times | default("[More Information Needed]", true)}}
|
| 103 |
+
|
| 104 |
+
## Evaluation
|
| 105 |
+
|
| 106 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 107 |
+
|
| 108 |
+
### Testing Data, Factors & Metrics
|
| 109 |
+
|
| 110 |
+
#### Testing Data
|
| 111 |
+
|
| 112 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 113 |
+
|
| 114 |
+
{{ testing_data | default("[More Information Needed]", true)}}
|
| 115 |
+
|
| 116 |
+
#### Factors
|
| 117 |
+
|
| 118 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 119 |
+
|
| 120 |
+
{{ testing_factors | default("[More Information Needed]", true)}}
|
| 121 |
+
|
| 122 |
+
#### Metrics
|
| 123 |
+
|
| 124 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 125 |
+
|
| 126 |
+
{{ testing_metrics | default("[More Information Needed]", true)}}
|
| 127 |
+
|
| 128 |
+
### Results
|
| 129 |
+
|
| 130 |
+
{{ results | default("[More Information Needed]", true)}}
|
| 131 |
+
|
| 132 |
+
#### Summary
|
| 133 |
+
|
| 134 |
+
{{ results_summary | default("", true) }}
|
| 135 |
+
|
| 136 |
+
## Model Examination [optional]
|
| 137 |
+
|
| 138 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 139 |
+
|
| 140 |
+
{{ model_examination | default("[More Information Needed]", true)}}
|
| 141 |
+
|
| 142 |
+
## Environmental Impact
|
| 143 |
+
|
| 144 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 145 |
+
|
| 146 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 147 |
+
|
| 148 |
+
- **Hardware Type:** {{ hardware_type | default("[More Information Needed]", true)}}
|
| 149 |
+
- **Hours used:** {{ hours_used | default("[More Information Needed]", true)}}
|
| 150 |
+
- **Cloud Provider:** {{ cloud_provider | default("[More Information Needed]", true)}}
|
| 151 |
+
- **Compute Region:** {{ cloud_region | default("[More Information Needed]", true)}}
|
| 152 |
+
- **Carbon Emitted:** {{ co2_emitted | default("[More Information Needed]", true)}}
|
| 153 |
+
|
| 154 |
+
## Technical Specifications [optional]
|
| 155 |
+
|
| 156 |
+
### Model Architecture and Objective
|
| 157 |
+
|
| 158 |
+
{{ model_specs | default("[More Information Needed]", true)}}
|
| 159 |
+
|
| 160 |
+
### Compute Infrastructure
|
| 161 |
+
|
| 162 |
+
{{ compute_infrastructure | default("[More Information Needed]", true)}}
|
| 163 |
+
|
| 164 |
+
#### Hardware
|
| 165 |
+
|
| 166 |
+
{{ hardware_requirements | default("[More Information Needed]", true)}}
|
| 167 |
+
|
| 168 |
+
#### Software
|
| 169 |
+
|
| 170 |
+
{{ software | default("[More Information Needed]", true)}}
|
| 171 |
+
|
| 172 |
+
## Citation [optional]
|
| 173 |
+
|
| 174 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 175 |
+
|
| 176 |
+
**BibTeX:**
|
| 177 |
+
|
| 178 |
+
{{ citation_bibtex | default("[More Information Needed]", true)}}
|
| 179 |
+
|
| 180 |
+
**APA:**
|
| 181 |
+
|
| 182 |
+
{{ citation_apa | default("[More Information Needed]", true)}}
|
| 183 |
+
|
| 184 |
+
## Glossary [optional]
|
| 185 |
+
|
| 186 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 187 |
+
|
| 188 |
+
{{ glossary | default("[More Information Needed]", true)}}
|
| 189 |
+
|
| 190 |
+
## More Information [optional]
|
| 191 |
+
|
| 192 |
+
{{ more_information | default("[More Information Needed]", true)}}
|
| 193 |
+
|
| 194 |
+
## Model Card Authors [optional]
|
| 195 |
+
|
| 196 |
+
{{ model_card_authors | default("[More Information Needed]", true)}}
|
| 197 |
+
|
| 198 |
+
## Model Card Contact
|
| 199 |
+
|
| 200 |
+
{{ model_card_contact | default("[More Information Needed]", true)}}
|
classifier/query_router.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Query Router System
|
| 3 |
+
|
| 4 |
+
This module integrates the medical/insurance classifier with the reason
|
| 5 |
+
classification system to provide intelligent routing of healthcare portal queries.
|
| 6 |
+
|
| 7 |
+
The router first determines if a query is medical or insurance-related, then
|
| 8 |
+
routes accordingly:
|
| 9 |
+
- Insurance queries -> Direct to insurance department
|
| 10 |
+
- Medical queries -> Reason classification -> Appropriate medical department routing
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from typing import Dict, List, Optional, Tuple
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
# Add project root to path for imports
|
| 19 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 20 |
+
if str(REPO_ROOT) not in sys.path:
|
| 21 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 22 |
+
|
| 23 |
+
from classifier.infer import predict_query
|
| 24 |
+
from classifier.utils import get_models, CATEGORIES
|
| 25 |
+
from classifier.reason import predict_single_reason
|
| 26 |
+
from retriever.search import Retriever
|
| 27 |
+
from team.candidates import get_candidates
|
| 28 |
+
|
| 29 |
+
class HealthcareQueryRouter:
    """
    Intelligent routing system for healthcare portal queries.

    Routes queries through a two-stage process:
    1. Medical vs Insurance classification
    2. For medical queries: Reason classification for department routing
    """

    def __init__(self,
                 medical_model_path: Optional[str] = None,
                 use_retrieval: bool = True):
        """
        Initialize the query router.

        Args:
            medical_model_path: Path to trained medical/insurance classifier
            use_retrieval: Whether to use retrieval system for medical queries
        """

        # Initialize medical/insurance classifier
        try:
            self.embedding_model, self.classifier_head = get_models()

            # Load trained model if available
            if medical_model_path and os.path.exists(medical_model_path):
                import torch
                state_dict = torch.load(medical_model_path, weights_only=True)
                self.classifier_head.load_state_dict(state_dict)
                print(f"Loaded medical/insurance classifier from {medical_model_path}")
            else:
                print("Using untrained medical/insurance classifier")

        except Exception as e:
            print(f"Error initializing medical/insurance classifier: {e}")
            raise

        # Initialize retrieval system if requested (best-effort: failures
        # disable retrieval but do not break routing)
        self.retriever = None
        if use_retrieval:
            try:
                # Use default corpora configuration
                corpora_config = {
                    "medical_qa": {
                        "path": "data/corpora/medical_qa.jsonl",
                        "text_fields": ["question", "answer", "title"],
                    },
                    "miriad": {
                        "path": "data/corpora/miriad_text.jsonl",
                        "text_fields": ["text", "title"],
                    }
                }
                # Only use available corpora
                available_config = {k: v for k, v in corpora_config.items()
                                    if Path(v["path"]).exists()}

                if available_config:
                    self.retriever = Retriever(available_config)
                    print(f"Retrieval system initialized with {len(available_config)} corpora")
                else:
                    print("No corpora files found. Retrieval disabled.")
            except Exception as e:
                print(f"Could not initialize retrieval system: {e}")

        # Routing rules for insurance queries
        self.insurance_routing = {
            "department": "Insurance Department",
            "priority": "normal",
            "estimated_response": "1-2 business days",
            "contact_method": "phone_or_email",
            "description": "Insurance coverage, claims, and benefits inquiries"
        }

        # Medical department routing based on reason categories
        self.medical_department_routing = {
            "ROUTINE_CARE": {
                "department": "Primary Care",
                "priority": "normal",
                "estimated_response": "1-7 days",
                "contact_method": "standard_scheduling",
                "description": "Routine healthcare and maintenance visits"
            },
            "PAIN_CONDITIONS": {
                "department": "Pain Management",
                "priority": "high",
                "estimated_response": "same day to 3 days",
                "contact_method": "phone_preferred",
                "description": "Pain-related conditions and discomfort"
            },
            "INJURIES": {
                "department": "Urgent Care",
                "priority": "high",
                "estimated_response": "same day",
                "contact_method": "phone_immediate",
                "description": "Injuries, sprains, and trauma-related conditions"
            },
            "SKIN_CONDITIONS": {
                "department": "Dermatology",
                "priority": "normal",
                "estimated_response": "3-7 days",
                "contact_method": "standard_scheduling",
                "description": "Skin-related issues and conditions"
            },
            "STRUCTURAL_ISSUES": {
                "department": "Orthopedics",
                "priority": "normal",
                "estimated_response": "1-14 days",
                "contact_method": "standard_scheduling",
                "description": "Structural problems and musculoskeletal conditions"
            },
            "PROCEDURES": {
                "department": "Surgical Services",
                "priority": "normal",
                "estimated_response": "3-14 days",
                "contact_method": "scheduling_coordinator",
                "description": "Surgical consultations and procedures"
            }
        }

    def route_query(self, query: str, include_retrieval: bool = True) -> Dict:
        """
        Route a healthcare query through the classification and routing system.

        Args:
            query: The user's query text
            include_retrieval: Whether to include retrieval results for medical queries

        Returns:
            Dictionary with routing decision, confidence, and additional context
        """

        # Step 1: Medical vs Insurance classification
        medical_prediction = predict_query([query], self.embedding_model, self.classifier_head)

        # Extract prediction details ('confidence' may be a float for a single
        # query or a list for a batch — normalize to a float here)
        primary_category = medical_prediction['prediction'][0]
        confidence = medical_prediction['confidence'] if isinstance(medical_prediction['confidence'], float) else medical_prediction['confidence'][0]
        probabilities = medical_prediction['probabilities']

        routing_result = {
            "query": query,
            "primary_classification": primary_category,
            "confidence": confidence,
            "all_probabilities": {
                # FIX: the original ternary had identical branches
                # (float(probabilities[i]) on both sides); for batched output
                # probabilities[0] is a list and float() would raise TypeError.
                # Index into the first row when the result is batched.
                CATEGORIES[i]: float(probabilities[0][i]) if isinstance(probabilities[0], list) else float(probabilities[i])
                for i in range(len(CATEGORIES))
            },
            "routing_decision": None,
            "reason_classification": None,
            "retrieval_results": None,
            "recommendations": []
        }

        # Step 2: Route based on classification
        if primary_category.lower() == "medical":
            routing_result["routing_decision"], routing_result["reason_classification"] = self._route_medical_query(query, include_retrieval)
        else:
            routing_result["routing_decision"] = self._route_insurance_query()

        # Step 3: Add contextual recommendations
        routing_result["recommendations"] = self._generate_recommendations(
            primary_category, confidence, routing_result.get("reason_classification")
        )

        return routing_result

    def _route_medical_query(self, query: str, include_retrieval: bool = True) -> Tuple[Dict, Dict]:
        """Route medical queries through reason classification.

        Returns:
            Tuple of (department routing dict, reason classification dict).
        """

        # Get reason classification
        try:
            reason_result = predict_single_reason(query)
            reason_category = reason_result['category']
            reason_confidence = reason_result['confidence']
            reason_probabilities = reason_result['probabilities']
        except Exception as e:
            print(f"Reason classification failed: {e}")
            # Fallback to general medical routing
            reason_category = "ROUTINE_CARE"
            reason_confidence = 0.5
            reason_probabilities = {}

        # Get department routing based on reason (copy so per-query additions
        # like retrieval results don't mutate the shared routing table)
        routing = self.medical_department_routing.get(
            reason_category,
            self.medical_department_routing["ROUTINE_CARE"]
        ).copy()

        # Add reason classification details
        reason_classification = {
            "category": reason_category,
            "confidence": reason_confidence,
            "probabilities": reason_probabilities
        }

        # Add retrieval results if available and requested
        if include_retrieval and self.retriever:
            try:
                retrieval_results = self.retriever.retrieve(query, k=5, for_ui=True)
                routing["retrieval_results"] = retrieval_results
            except Exception as e:
                print(f"Retrieval failed: {e}")
                routing["retrieval_results"] = []

        return routing, reason_classification

    def _route_insurance_query(self) -> Dict:
        """Route insurance queries to insurance department."""
        return self.insurance_routing.copy()

    def _generate_recommendations(self, primary_category: str, confidence: float, reason_classification: Dict = None) -> List[str]:
        """Generate contextual recommendations based on classification."""

        recommendations = []

        # Low confidence warning
        if confidence < 0.7:
            recommendations.append(
                "Classification confidence is low. Consider manual review or "
                "asking the user to clarify their request."
            )

        # Category-specific recommendations
        if primary_category.lower() == "medical":
            recommendations.extend([
                "Consider asking follow-up questions about symptoms",
                "Verify if this requires immediate attention",
                "Check if patient has existing appointments or conditions"
            ])

            # Reason-specific recommendations
            if reason_classification:
                reason_category = reason_classification.get('category')
                if reason_category == "PAIN_CONDITIONS":
                    recommendations.append("Assess pain level and duration for urgency determination")
                elif reason_category == "INJURIES":
                    recommendations.append("Determine if immediate medical attention is required")
                elif reason_category == "PROCEDURES":
                    recommendations.append("Verify insurance pre-authorization requirements")

        elif primary_category.lower() == "insurance":
            recommendations.extend([
                "Have patient account information ready",
                "Verify current insurance information and benefits",
                "Prepare to explain coverage details and requirements"
            ])

        return recommendations

    def batch_route_queries(self, queries: List[str]) -> List[Dict]:
        """Route multiple queries efficiently."""
        return [self.route_query(query) for query in queries]

    def get_routing_statistics(self, queries: List[str]) -> Dict:
        """Analyze routing patterns for a batch of queries."""

        results = self.batch_route_queries(queries)

        # Count categories
        primary_counts = {}
        reason_counts = {}
        confidence_scores = []

        for result in results:
            # Primary classification counts
            primary_category = result["primary_classification"]
            primary_counts[primary_category] = primary_counts.get(primary_category, 0) + 1
            confidence_scores.append(result["confidence"])

            # Reason classification counts (for medical queries)
            if result["reason_classification"]:
                reason_category = result["reason_classification"]["category"]
                reason_counts[reason_category] = reason_counts.get(reason_category, 0) + 1

        return {
            "total_queries": len(queries),
            "primary_distribution": primary_counts,
            "reason_distribution": reason_counts,
            # FIX: guard against an empty query list (ZeroDivisionError)
            "average_confidence": (sum(confidence_scores) / len(confidence_scores)) if confidence_scores else 0.0,
            "low_confidence_queries": len([c for c in confidence_scores if c < 0.7]),
            "primary_percentages": {
                cat: (count / len(queries)) * 100
                for cat, count in primary_counts.items()
            },
            "reason_percentages": {
                cat: (count / len(queries)) * 100
                for cat, count in reason_counts.items()
            }
        }
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def demo_router():
|
| 321 |
+
"""Demonstrate the query router functionality."""
|
| 322 |
+
|
| 323 |
+
print("Initializing Healthcare Query Router...")
|
| 324 |
+
router = HealthcareQueryRouter()
|
| 325 |
+
|
| 326 |
+
# Test queries covering different categories
|
| 327 |
+
test_queries = [
|
| 328 |
+
# Insurance queries
|
| 329 |
+
"My insurance claim was denied, can you help?",
|
| 330 |
+
"What does my insurance cover for this procedure?",
|
| 331 |
+
"I need to verify my insurance benefits",
|
| 332 |
+
|
| 333 |
+
# Medical queries - different reasons
|
| 334 |
+
"I have heel pain when I walk", # PAIN_CONDITIONS
|
| 335 |
+
"I need routine foot care", # ROUTINE_CARE
|
| 336 |
+
"I sprained my ankle playing sports", # INJURIES
|
| 337 |
+
"My toenail is ingrown and infected", # SKIN_CONDITIONS
|
| 338 |
+
"I have flat feet and need evaluation", # STRUCTURAL_ISSUES
|
| 339 |
+
"I need a cortisone injection", # PROCEDURES
|
| 340 |
+
]
|
| 341 |
+
|
| 342 |
+
print(f"\nRouting {len(test_queries)} test queries...\n")
|
| 343 |
+
|
| 344 |
+
for i, query in enumerate(test_queries, 1):
|
| 345 |
+
print(f"Query {i}: {query}")
|
| 346 |
+
|
| 347 |
+
result = router.route_query(query)
|
| 348 |
+
|
| 349 |
+
print(f" Primary: {result['primary_classification']} "
|
| 350 |
+
f"(confidence: {result['confidence']:.3f})")
|
| 351 |
+
|
| 352 |
+
if result['reason_classification']:
|
| 353 |
+
print(f" Reason: {result['reason_classification']['category']} "
|
| 354 |
+
f"(confidence: {result['reason_classification']['confidence']:.3f})")
|
| 355 |
+
|
| 356 |
+
print(f" Department: {result['routing_decision']['department']}")
|
| 357 |
+
print(f" Priority: {result['routing_decision']['priority']}")
|
| 358 |
+
print(f" Response Time: {result['routing_decision']['estimated_response']}")
|
| 359 |
+
|
| 360 |
+
if result['recommendations']:
|
| 361 |
+
print(f" Recommendation: {result['recommendations'][0]}")
|
| 362 |
+
|
| 363 |
+
print()
|
| 364 |
+
|
| 365 |
+
# Show routing statistics
|
| 366 |
+
print("Routing Statistics:")
|
| 367 |
+
stats = router.get_routing_statistics(test_queries)
|
| 368 |
+
|
| 369 |
+
print("Primary Classification:")
|
| 370 |
+
for category, percentage in stats['primary_percentages'].items():
|
| 371 |
+
print(f" {category}: {percentage:.1f}%")
|
| 372 |
+
|
| 373 |
+
if stats['reason_percentages']:
|
| 374 |
+
print("Reason Classification:")
|
| 375 |
+
for category, percentage in stats['reason_percentages'].items():
|
| 376 |
+
print(f" {category}: {percentage:.1f}%")
|
| 377 |
+
|
| 378 |
+
print(f"Average Confidence: {stats['average_confidence']:.3f}")
|
| 379 |
+
print(f"Low Confidence Queries: {stats['low_confidence_queries']}")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
if __name__ == "__main__":
|
| 383 |
+
demo_router()
|
classifier/reason/README.md
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Healthcare Reason Classification System
|
| 2 |
+
|
| 3 |
+
This module implements a specialized classifier for healthcare visit reasons using real clinic data to classify patient queries into specific healthcare reason categories.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
The reason classifier addresses the challenge of routing medical healthcare queries to appropriate specialized departments. It classifies medical queries into specific reason categories based on actual healthcare visit data.
|
| 8 |
+
|
| 9 |
+
## Architecture
|
| 10 |
+
|
| 11 |
+
### Classification Categories
|
| 12 |
+
|
| 13 |
+
| Category | Description | Examples |
|
| 14 |
+
|----------|-------------|----------|
|
| 15 |
+
| `ROUTINE_CARE` | Routine healthcare, maintenance visits, general care | "I need routine foot care", "Regular nail care appointment" |
|
| 16 |
+
| `PAIN_CONDITIONS` | Various pain-related conditions and discomfort | "I have heel pain when I walk", "My ankle is sore" |
|
| 17 |
+
| `INJURIES` | Sprains, wounds, trauma-related conditions | "I sprained my ankle playing sports", "I have a wound that won't heal" |
|
| 18 |
+
| `SKIN_CONDITIONS` | Skin-related issues and conditions | "My toenail is ingrown and infected", "I have calluses on my feet" |
|
| 19 |
+
| `STRUCTURAL_ISSUES` | Structural problems and related conditions | "I have flat feet", "I need evaluation for plantar fasciitis" |
|
| 20 |
+
| `PROCEDURES` | Injections, surgical consultations, post-operative care | "I need a cortisone injection", "Post-surgical follow-up" |
|
| 21 |
+
|
| 22 |
+
### Technical Implementation
|
| 23 |
+
|
| 24 |
+
- **Base Model**: `sentence-transformers/embeddinggemma-300m-medical`
|
| 25 |
+
- **Architecture**: SetFit with frozen embeddings + trainable classification head
|
| 26 |
+
- **Training**: Real healthcare data from clinic appointment records
|
| 27 |
+
- **Integration**: Works as part of the complete healthcare routing system
|
| 28 |
+
|
| 29 |
+
## Quick Start
|
| 30 |
+
|
| 31 |
+
### 1. Train the Classifier
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
# Train with real healthcare data
|
| 35 |
+
python classifier/reason/train_reason.py
|
| 36 |
+
|
| 37 |
+
# The training script will:
|
| 38 |
+
# - Load real healthcare data from data/reason_for_visit_data.xlsx
|
| 39 |
+
# - Map reasons to categories using keyword matching
|
| 40 |
+
# - Train the classifier with frozen embeddings
|
| 41 |
+
# - Save the trained model to classifier/reason_checkpoints/
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
### 2. Use the CLI
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
# Classify a single reason query
|
| 48 |
+
python cli/reason_classifier_cli_new.py "I have heel pain when I walk"
|
| 49 |
+
|
| 50 |
+
# Interactive mode
|
| 51 |
+
python cli/reason_classifier_cli_new.py --interactive
|
| 52 |
+
|
| 53 |
+
# Batch processing
|
| 54 |
+
python cli/reason_classifier_cli_new.py --batch queries.txt --output results.json
|
| 55 |
+
|
| 56 |
+
# Use complete healthcare routing system
|
| 57 |
+
python cli/healthcare_classifier_cli.py "I need routine foot care"
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
### 3. Programmatic Usage
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
from classifier.reason import ReasonClassifier, predict_single_reason
|
| 64 |
+
|
| 65 |
+
# Using the main classifier class
|
| 66 |
+
classifier = ReasonClassifier()
|
| 67 |
+
predictions = classifier.predict(["I have heel pain when I walk"])
|
| 68 |
+
print(predictions[0]['category']) # Output: PAIN_CONDITIONS
|
| 69 |
+
|
| 70 |
+
# Using convenience function
|
| 71 |
+
result = predict_single_reason("I need routine foot care")
|
| 72 |
+
print(result['category']) # Output: ROUTINE_CARE
|
| 73 |
+
print(result['confidence']) # Confidence score
|
| 74 |
+
print(result['probabilities']) # All category probabilities
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## System Integration
|
| 78 |
+
|
| 79 |
+
### Complete Healthcare Routing Workflow
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
User Query
|
| 83 |
+
↓
|
| 84 |
+
Medical vs Insurance Classification
|
| 85 |
+
↓
|
| 86 |
+
┌─────────────────┬─────────────────┐
|
| 87 |
+
│ Insurance │ Medical │
|
| 88 |
+
│ Queries │ Queries │
|
| 89 |
+
│ ↓ │ ↓ │
|
| 90 |
+
│ Insurance │ Reason │
|
| 91 |
+
│ Department │ Classification │
|
| 92 |
+
│ │ ↓ │
|
| 93 |
+
│ │ • ROUTINE_CARE │
|
| 94 |
+
│ │ • PAIN_CONDITIONS │
|
| 95 |
+
│ │ • INJURIES │
|
| 96 |
+
│ │ • SKIN_CONDITIONS │
|
| 97 |
+
│ │ • STRUCTURAL_ISSUES │
|
| 98 |
+
│ │ • PROCEDURES │
|
| 99 |
+
└─────────────────┴─────────────────┘
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### Integration with Healthcare System
|
| 103 |
+
|
| 104 |
+
The reason classifier integrates as part of the complete healthcare routing system:
|
| 105 |
+
|
| 106 |
+
1. **Primary Classification**: Medical vs Insurance queries
|
| 107 |
+
2. **Reason Classification**: Medical queries → Specific reason categories
|
| 108 |
+
3. **Department Routing**: Route to appropriate specialized departments
|
| 109 |
+
|
| 110 |
+
## Training Data Strategy
|
| 111 |
+
|
| 112 |
+
### Real Healthcare Data
|
| 113 |
+
|
| 114 |
+
The system uses actual healthcare clinic data:
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
# Data source: data/reason_for_visit_data.xlsx
|
| 118 |
+
# Contains real patient visit reasons and appointment types
|
| 119 |
+
# Examples from actual data:
|
| 120 |
+
# - "Heel pain"
|
| 121 |
+
# - "Routine foot care"
|
| 122 |
+
# - "Ingrown toenail"
|
| 123 |
+
# - "Ankle sprain"
|
| 124 |
+
# - "Plantar fasciitis"
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### Category Mapping Strategy
|
| 128 |
+
|
| 129 |
+
The system uses keyword-based mapping to categorize real healthcare reasons:
|
| 130 |
+
|
| 131 |
+
```python
|
| 132 |
+
def map_reason_to_category(reason: str) -> int:
|
| 133 |
+
reason_lower = reason.lower()
|
| 134 |
+
|
| 135 |
+
# ROUTINE_CARE (routine care, maintenance visits)
|
| 136 |
+
if any(word in reason_lower for word in ['routine', 'nail care', 'calluses']):
|
| 137 |
+
return 0
|
| 138 |
+
|
| 139 |
+
# PAIN_CONDITIONS (various pain-related conditions)
|
| 140 |
+
elif any(word in reason_lower for word in ['pain', 'ache', 'sore']):
|
| 141 |
+
return 1
|
| 142 |
+
|
| 143 |
+
# ... other categories
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
## Performance Metrics
|
| 147 |
+
|
| 148 |
+
### Expected Performance
|
| 149 |
+
- **Accuracy**: Based on real healthcare data patterns
|
| 150 |
+
- **Categories**: 6 specialized healthcare reason categories
|
| 151 |
+
- **Confidence**: Variable based on training data quality
|
| 152 |
+
|
| 153 |
+
### Evaluation Framework
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
# Train and evaluate the model
|
| 157 |
+
python classifier/reason/train_reason.py
|
| 158 |
+
|
| 159 |
+
# Test the trained model
|
| 160 |
+
python classifier/reason/infer_reason.py
|
| 161 |
+
|
| 162 |
+
# Results include:
|
| 163 |
+
# - Training metrics
|
| 164 |
+
# - Category distribution
|
| 165 |
+
# - Example predictions with confidence scores
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
## File Structure
|
| 169 |
+
|
| 170 |
+
```
|
| 171 |
+
classifier/reason/
|
| 172 |
+
├── __init__.py # Package initialization and exports
|
| 173 |
+
├── README.md # This documentation
|
| 174 |
+
├── reason_classifier.py # Main ReasonClassifier class
|
| 175 |
+
├── infer_reason.py # Inference functions and utilities
|
| 176 |
+
└── train_reason.py # Training script and functions
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
## API Reference
|
| 180 |
+
|
| 181 |
+
### ReasonClassifier
|
| 182 |
+
|
| 183 |
+
```python
|
| 184 |
+
class ReasonClassifier:
|
| 185 |
+
def __init__(self, data_file: str = "data/reason_for_visit_data.xlsx")
|
| 186 |
+
def predict(self, queries: List[str]) -> List[Dict]
|
| 187 |
+
def train(self, train_data: pd.DataFrame = None, eval_data: Optional[pd.DataFrame] = None)
|
| 188 |
+
def save_model(self, path: str)
|
| 189 |
+
def load_model(self, path: str)
|
| 190 |
+
def create_real_dataset(self) -> pd.DataFrame
|
| 191 |
+
def analyze_real_data(self)
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### Inference Functions
|
| 195 |
+
|
| 196 |
+
```python
|
| 197 |
+
def predict_single_reason(query: str) -> dict
|
| 198 |
+
def predict_reason_query(text: list[str], embedding_model, classifier_head) -> dict
|
| 199 |
+
def get_reason_models() -> tuple
|
| 200 |
+
def test_reason_classifier()
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
### Training Functions
|
| 204 |
+
|
| 205 |
+
```python
|
| 206 |
+
def get_reason_model(num_classes: int)
|
| 207 |
+
def get_reason_dataset() -> pd.DataFrame
|
| 208 |
+
def map_reason_to_category(reason: str) -> int
|
| 209 |
+
def preprocess_reason_data(df: pd.DataFrame) -> pd.DataFrame
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
## Data Requirements
|
| 213 |
+
|
| 214 |
+
### Healthcare Data Format
|
| 215 |
+
|
| 216 |
+
The system expects healthcare data in Excel format with these columns:
|
| 217 |
+
|
| 218 |
+
```
|
| 219 |
+
Required columns:
|
| 220 |
+
- "Reason For Visit": The primary reason for the healthcare visit
|
| 221 |
+
- "Appointment Type": Type of appointment (optional, used for context)
|
| 222 |
+
|
| 223 |
+
Example data:
|
| 224 |
+
| Reason For Visit | Appointment Type |
|
| 225 |
+
|------------------|------------------|
|
| 226 |
+
| Heel pain | Follow-up |
|
| 227 |
+
| Routine foot care| Maintenance |
|
| 228 |
+
| Ingrown toenail | New Patient |
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
## Deployment Considerations
|
| 232 |
+
|
| 233 |
+
### Production Readiness
|
| 234 |
+
|
| 235 |
+
1. **Model Persistence**: Trained models saved with timestamps in `classifier/reason_checkpoints/`
|
| 236 |
+
2. **Error Handling**: Graceful fallbacks for prediction failures
|
| 237 |
+
3. **Real Data Integration**: Uses actual healthcare clinic data
|
| 238 |
+
4. **Device Support**: CPU/GPU/MPS compatibility
|
| 239 |
+
|
| 240 |
+
### Scalability
|
| 241 |
+
|
| 242 |
+
- **Batch Processing**: Efficient handling of multiple queries
|
| 243 |
+
- **Integration**: Works with existing healthcare routing system
|
| 244 |
+
- **Checkpoints**: Automatic model saving with timestamps
|
| 245 |
+
|
| 246 |
+
## Future Enhancements
|
| 247 |
+
|
| 248 |
+
### Data Improvements
|
| 249 |
+
|
| 250 |
+
1. **Expanded Dataset**: Include more healthcare specialties
|
| 251 |
+
2. **Active Learning**: Improve model with real-world feedback
|
| 252 |
+
3. **Multi-language Support**: Support for non-English healthcare queries
|
| 253 |
+
|
| 254 |
+
### Advanced Features
|
| 255 |
+
|
| 256 |
+
1. **Confidence Calibration**: Improve confidence score reliability
|
| 257 |
+
2. **Hierarchical Classification**: Sub-categories within reason types
|
| 258 |
+
3. **Context Awareness**: Consider patient history and appointment context
|
| 259 |
+
|
| 260 |
+
## Troubleshooting
|
| 261 |
+
|
| 262 |
+
### Common Issues
|
| 263 |
+
|
| 264 |
+
1. **Data Loading Errors**: Ensure `data/reason_for_visit_data.xlsx` exists
|
| 265 |
+
2. **Low Confidence**: May indicate need for more training data or model retraining
|
| 266 |
+
3. **Import Errors**: Ensure all dependencies are installed and paths are correct
|
| 267 |
+
|
| 268 |
+
### Debug Mode
|
| 269 |
+
|
| 270 |
+
```python
|
| 271 |
+
# Test the classifier with sample queries
|
| 272 |
+
from classifier.reason.infer_reason import test_reason_classifier
|
| 273 |
+
test_reason_classifier()
|
| 274 |
+
|
| 275 |
+
# Check model predictions with probabilities
|
| 276 |
+
from classifier.reason import predict_single_reason
|
| 277 |
+
result = predict_single_reason("ambiguous query")
|
| 278 |
+
print(result['probabilities'])
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
### Model Training Issues
|
| 282 |
+
|
| 283 |
+
```bash
|
| 284 |
+
# Check if healthcare data is available
|
| 285 |
+
ls -la data/reason_for_visit_data.xlsx
|
| 286 |
+
|
| 287 |
+
# Verify model training
|
| 288 |
+
python classifier/reason/train_reason.py
|
| 289 |
+
|
| 290 |
+
# Test inference after training
|
| 291 |
+
python classifier/reason/infer_reason.py
|
| 292 |
+
```
|
| 293 |
+
|
| 294 |
+
## Contributing
|
| 295 |
+
|
| 296 |
+
### Adding New Categories
|
| 297 |
+
|
| 298 |
+
1. Update `REASON_CATEGORIES` in `reason_classifier.py`, `infer_reason.py`, and `train_reason.py`
|
| 299 |
+
2. Update category mapping logic in `map_reason_to_category()`
|
| 300 |
+
3. Retrain the model with new categories
|
| 301 |
+
4. Update documentation and examples
|
| 302 |
+
|
| 303 |
+
### Improving Training Data
|
| 304 |
+
|
| 305 |
+
1. Add more real healthcare examples to the dataset
|
| 306 |
+
2. Improve keyword mapping for better categorization
|
| 307 |
+
3. Implement more sophisticated NLP techniques for category assignment
|
| 308 |
+
|
| 309 |
+
## License
|
| 310 |
+
|
| 311 |
+
This module is part of the health-query-classifier project and follows the same licensing terms.
|
classifier/reason/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reason classification module for healthcare queries.
|
| 3 |
+
|
| 4 |
+
This module contains components for classifying healthcare visit reasons
|
| 5 |
+
into predefined categories based on real healthcare data.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .reason_classifier import ReasonClassifier, REASON_CATEGORIES
|
| 9 |
+
from .infer_reason import predict_reason_query, predict_single_reason, get_reason_models
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
'ReasonClassifier',
|
| 13 |
+
'REASON_CATEGORIES',
|
| 14 |
+
'predict_reason_query',
|
| 15 |
+
'predict_single_reason',
|
| 16 |
+
'get_reason_models'
|
| 17 |
+
]
|
classifier/reason/infer_reason.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference module for Healthcare Reason Classification
|
| 3 |
+
|
| 4 |
+
This module provides inference for the reason classification system,
|
| 5 |
+
separate from the medical/insurance classifier.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from ..head import ClassifierHead
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import os
|
| 11 |
+
import pprint
|
| 12 |
+
import torch
|
| 13 |
+
from sentence_transformers import SentenceTransformer
|
| 14 |
+
|
| 15 |
+
# Reason-specific configuration
|
| 16 |
+
REASON_CATEGORIES = {
|
| 17 |
+
0: "ROUTINE_CARE",
|
| 18 |
+
1: "PAIN_CONDITIONS",
|
| 19 |
+
2: "INJURIES",
|
| 20 |
+
3: "SKIN_CONDITIONS",
|
| 21 |
+
4: "STRUCTURAL_ISSUES",
|
| 22 |
+
5: "PROCEDURES"
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints"
|
| 26 |
+
DATETIME_FORMAT = "%Y%m%d_%H%M%S"
|
| 27 |
+
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
|
| 28 |
+
|
| 29 |
+
def get_device():
|
| 30 |
+
"""Get the best available device for inference."""
|
| 31 |
+
if torch.backends.mps.is_available():
|
| 32 |
+
return torch.device("mps")
|
| 33 |
+
elif torch.cuda.is_available():
|
| 34 |
+
return torch.device("cuda")
|
| 35 |
+
else:
|
| 36 |
+
return torch.device("cpu")
|
| 37 |
+
|
| 38 |
+
DEVICE = get_device()
|
| 39 |
+
|
| 40 |
+
def get_reason_models():
|
| 41 |
+
"""Get the embedding model and classifier head for reason inference."""
|
| 42 |
+
# Load embedding model
|
| 43 |
+
embedding_model = SentenceTransformer(
|
| 44 |
+
MODEL_NAME,
|
| 45 |
+
prompts={
|
| 46 |
+
'classification': 'task: healthcare reason classification | query: ',
|
| 47 |
+
'retrieval (query)': 'task: search result | query: ',
|
| 48 |
+
'retrieval (document)': 'title: {title | "none"} | text: ',
|
| 49 |
+
},
|
| 50 |
+
default_prompt_name='classification',
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Load classifier head (for 6 reason categories)
|
| 54 |
+
classifier_head = ClassifierHead(len(REASON_CATEGORIES))
|
| 55 |
+
|
| 56 |
+
return embedding_model.to(DEVICE), classifier_head.to(DEVICE)
|
| 57 |
+
|
| 58 |
+
def predict_reason_query(
|
| 59 |
+
text: list[str],
|
| 60 |
+
embedding_model: SentenceTransformer,
|
| 61 |
+
classifier_head: ClassifierHead,
|
| 62 |
+
) -> dict:
|
| 63 |
+
"""
|
| 64 |
+
Runs the full inference pipeline for reason classification: Text -> Embedding -> Classification.
|
| 65 |
+
"""
|
| 66 |
+
# Set models to evaluation mode
|
| 67 |
+
embedding_model.eval()
|
| 68 |
+
classifier_head.eval()
|
| 69 |
+
|
| 70 |
+
with torch.no_grad():
|
| 71 |
+
# Embed the text
|
| 72 |
+
embeddings = embedding_model.encode(
|
| 73 |
+
text,
|
| 74 |
+
convert_to_tensor=True,
|
| 75 |
+
device=DEVICE
|
| 76 |
+
).to(DEVICE)
|
| 77 |
+
|
| 78 |
+
# Calculate probabilities and prediction
|
| 79 |
+
probabilities = classifier_head.predict_proba(embeddings)
|
| 80 |
+
|
| 81 |
+
# Get the predicted index and confidence
|
| 82 |
+
predicted_indices = torch.argmax(probabilities, dim=1)
|
| 83 |
+
|
| 84 |
+
# Convert tensors to Python types safely
|
| 85 |
+
if predicted_indices.dim() == 0: # Single prediction
|
| 86 |
+
predicted_indices = [predicted_indices.item()]
|
| 87 |
+
else:
|
| 88 |
+
predicted_indices = predicted_indices.cpu().tolist()
|
| 89 |
+
|
| 90 |
+
# Get confidences
|
| 91 |
+
confidences = []
|
| 92 |
+
for i, idx in enumerate(predicted_indices):
|
| 93 |
+
conf = probabilities[i][idx].item() if probabilities.dim() > 1 else probabilities[idx].item()
|
| 94 |
+
confidences.append(conf)
|
| 95 |
+
|
| 96 |
+
# Get the predicted label names
|
| 97 |
+
predicted_labels = [REASON_CATEGORIES[i] for i in predicted_indices]
|
| 98 |
+
|
| 99 |
+
return {
|
| 100 |
+
'prediction': predicted_labels,
|
| 101 |
+
'confidence': confidences,
|
| 102 |
+
'probabilities': probabilities.cpu().tolist()
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
def predict_single_reason(query: str) -> dict:
|
| 106 |
+
"""Convenience function to predict a single reason query."""
|
| 107 |
+
try:
|
| 108 |
+
embedding_model, classifier_head = get_reason_models()
|
| 109 |
+
|
| 110 |
+
# Try to load the most recent trained checkpoint
|
| 111 |
+
if os.path.exists(REASON_CHECKPOINT_PATH):
|
| 112 |
+
for d in os.listdir(REASON_CHECKPOINT_PATH):
|
| 113 |
+
if d.endswith('.pt'):
|
| 114 |
+
checkpoint_path = f"{REASON_CHECKPOINT_PATH}/{d}"
|
| 115 |
+
try:
|
| 116 |
+
state_dict = torch.load(checkpoint_path, weights_only=True, map_location=DEVICE)
|
| 117 |
+
classifier_head.load_state_dict(state_dict)
|
| 118 |
+
print(f"Loaded trained weights from {checkpoint_path}")
|
| 119 |
+
break
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f"Could not load weights from {checkpoint_path}: {e}")
|
| 122 |
+
|
| 123 |
+
result = predict_reason_query([query], embedding_model, classifier_head)
|
| 124 |
+
|
| 125 |
+
# Extract values safely
|
| 126 |
+
prediction = result['prediction'][0] if isinstance(result['prediction'], list) else str(result['prediction'])
|
| 127 |
+
confidence = result['confidence'] if isinstance(result['confidence'], float) else (result['confidence'][0] if isinstance(result['confidence'], list) else float(result['confidence']))
|
| 128 |
+
|
| 129 |
+
# Handle probabilities - ensure it's a list
|
| 130 |
+
probabilities = result['probabilities']
|
| 131 |
+
if isinstance(probabilities, list) and len(probabilities) > 0:
|
| 132 |
+
if isinstance(probabilities[0], list):
|
| 133 |
+
probabilities = probabilities[0]
|
| 134 |
+
|
| 135 |
+
# Create probability dictionary
|
| 136 |
+
prob_dict = {}
|
| 137 |
+
for i, category in REASON_CATEGORIES.items():
|
| 138 |
+
if i < len(probabilities):
|
| 139 |
+
prob_dict[category] = float(probabilities[i])
|
| 140 |
+
else:
|
| 141 |
+
prob_dict[category] = 0.0
|
| 142 |
+
|
| 143 |
+
return {
|
| 144 |
+
'query': query,
|
| 145 |
+
'category': prediction,
|
| 146 |
+
'confidence': confidence,
|
| 147 |
+
'probabilities': prob_dict
|
| 148 |
+
}
|
| 149 |
+
except Exception as e:
|
| 150 |
+
# Return a default classification if the model fails
|
| 151 |
+
return {
|
| 152 |
+
'query': query,
|
| 153 |
+
'category': 'GENERAL_MEDICAL',
|
| 154 |
+
'confidence': 0.5,
|
| 155 |
+
'probabilities': {category: 1.0/len(REASON_CATEGORIES) for category in REASON_CATEGORIES.values()},
|
| 156 |
+
'error': str(e)
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
def test_reason_classifier():
|
| 160 |
+
"""Test the reason classifier with sample queries."""
|
| 161 |
+
latest = None
|
| 162 |
+
path = ""
|
| 163 |
+
|
| 164 |
+
# Try to load the most recent checkpoint
|
| 165 |
+
if os.path.exists(REASON_CHECKPOINT_PATH):
|
| 166 |
+
for d in os.listdir(REASON_CHECKPOINT_PATH):
|
| 167 |
+
if d.endswith('.pt'):
|
| 168 |
+
checkpoint_path = f"{REASON_CHECKPOINT_PATH}/{d}"
|
| 169 |
+
print(f"Found checkpoint: {checkpoint_path}")
|
| 170 |
+
path = checkpoint_path
|
| 171 |
+
break
|
| 172 |
+
|
| 173 |
+
if not path:
|
| 174 |
+
print("No trained checkpoints found. Using untrained model.")
|
| 175 |
+
else:
|
| 176 |
+
print("No checkpoint directory found. Using untrained model.")
|
| 177 |
+
|
| 178 |
+
embedding_model, classifier = get_reason_models()
|
| 179 |
+
|
| 180 |
+
# Load trained weights if available
|
| 181 |
+
if path and os.path.exists(path):
|
| 182 |
+
try:
|
| 183 |
+
state_dict = torch.load(path, weights_only=True, map_location=DEVICE)
|
| 184 |
+
classifier.load_state_dict(state_dict)
|
| 185 |
+
print(f"Loaded trained weights from {path}")
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print(f"Could not load weights: {e}. Using untrained model.")
|
| 188 |
+
|
| 189 |
+
# Test queries for reason classification
|
| 190 |
+
queries = [
|
| 191 |
+
"I have heel pain when I walk",
|
| 192 |
+
"My toenail is ingrown and painful",
|
| 193 |
+
"I need routine foot care",
|
| 194 |
+
"I sprained my ankle playing sports",
|
| 195 |
+
"I have plantar fasciitis",
|
| 196 |
+
"I need a cortisone injection"
|
| 197 |
+
]
|
| 198 |
+
|
| 199 |
+
print("\nTesting reason classification:")
|
| 200 |
+
pred = predict_reason_query(
|
| 201 |
+
text=queries,
|
| 202 |
+
embedding_model=embedding_model,
|
| 203 |
+
classifier_head=classifier,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
pprint.pprint(pred, indent=4)
|
| 207 |
+
|
| 208 |
+
if __name__ == "__main__":
|
| 209 |
+
test_reason_classifier()
|
classifier/reason/reason_classifier.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Healthcare Reason for Visit Classifier
|
| 3 |
+
|
| 4 |
+
This module implements a classifier for healthcare clinic queries
|
| 5 |
+
using real healthcare data from clinic appointment records.
|
| 6 |
+
|
| 7 |
+
Categories based on the actual data:
|
| 8 |
+
- ROUTINE_CARE: Routine care, maintenance visits
|
| 9 |
+
- PAIN_CONDITIONS: Various pain-related conditions
|
| 10 |
+
- INJURIES: Sprains, wounds, trauma-related visits
|
| 11 |
+
- SKIN_CONDITIONS: Skin-related conditions and issues
|
| 12 |
+
- STRUCTURAL_ISSUES: Structural problems and conditions
|
| 13 |
+
- PROCEDURES: Injections, surgical consults, postop care
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import torch
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import numpy as np
|
| 20 |
+
from typing import List, Dict, Tuple, Optional
|
| 21 |
+
from sentence_transformers import SentenceTransformer
|
| 22 |
+
from setfit import SetFitModel
|
| 23 |
+
from sklearn.model_selection import train_test_split
|
| 24 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
| 25 |
+
from datasets import Dataset
|
| 26 |
+
import json
|
| 27 |
+
|
| 28 |
+
from ..head import ClassifierHead
|
| 29 |
+
|
| 30 |
+
# Healthcare reason categories based on real data analysis
|
| 31 |
+
REASON_CATEGORIES = {
|
| 32 |
+
0: "ROUTINE_CARE",
|
| 33 |
+
1: "PAIN_CONDITIONS",
|
| 34 |
+
2: "INJURIES",
|
| 35 |
+
3: "SKIN_CONDITIONS",
|
| 36 |
+
4: "STRUCTURAL_ISSUES",
|
| 37 |
+
5: "PROCEDURES"
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
CATEGORY_DESCRIPTIONS = {
|
| 41 |
+
"ROUTINE_CARE": "Routine healthcare, maintenance visits, general care",
|
| 42 |
+
"PAIN_CONDITIONS": "Various pain-related conditions and discomfort",
|
| 43 |
+
"INJURIES": "Sprains, wounds, trauma-related conditions",
|
| 44 |
+
"SKIN_CONDITIONS": "Skin-related issues and conditions",
|
| 45 |
+
"STRUCTURAL_ISSUES": "Structural problems and related conditions",
|
| 46 |
+
"PROCEDURES": "Injections, surgical consultations, post-operative care"
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
class ReasonClassifier:
    """
    Healthcare Reason Classifier that uses real clinic data to classify
    patient queries into specific healthcare reason categories.

    Wraps a SetFit model: a frozen medical sentence-embedding body plus a
    trainable ``ClassifierHead`` over ``REASON_CATEGORIES``.
    """

    def __init__(self, data_file: str = "data/reason_for_visit_data.xlsx"):
        # Backbone embedding model identifier (body is frozen during training).
        self.model_name = "sentence-transformers/embeddinggemma-300m-medical"
        # Number of output classes, derived from the module-level category map.
        self.num_classes = len(REASON_CATEGORIES)
        # id -> category-name mapping (may be replaced by load_model()).
        self.categories = REASON_CATEGORIES
        # Path to the Excel file with raw clinic records.
        self.data_file = data_file
        # SetFitModel instance; populated by _initialize_model()/load_model().
        self.model = None
        self.device = self._get_device()

        # Load and process real data
        self.healthcare_df = self._load_data()
        self._initialize_model()

    def _get_device(self) -> torch.device:
        """Get the best available device for training/inference (MPS > CUDA > CPU)."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        else:
            return torch.device("cpu")

    def _load_data(self) -> pd.DataFrame:
        """Load the real healthcare dataset.

        Returns the raw DataFrame; expects a 'Reason For Visit' column.

        Raises:
            RuntimeError: if the Excel file cannot be read.
        """
        try:
            df = pd.read_excel(self.data_file)
            print(f"Loaded {len(df)} healthcare records from {self.data_file}")
            print(f"Unique reasons: {df['Reason For Visit'].nunique()}")
            return df
        except Exception as e:
            print(f"Error loading data: {e}")
            raise RuntimeError(f"Failed to load healthcare data from {self.data_file}")

    def _initialize_model(self) -> None:
        """Initialize the model with the existing infrastructure.

        Builds the SentenceTransformer body with task-specific prompts, attaches
        a fresh ClassifierHead, freezes the body, and moves the model to device.

        Raises:
            RuntimeError: if the embedding model cannot be loaded.
        """
        try:
            model_body = SentenceTransformer(
                self.model_name,
                prompts={
                    'classification': 'task: healthcare reason classification | query: ',
                    'retrieval (query)': 'task: search result | query: ',
                    'retrieval (document)': 'title: {title | "none"} | text: ',
                },
                default_prompt_name='classification',
            )

            # NOTE(review): embedding_dim=768 is hard-coded — confirm it matches
            # the backbone's output dimension.
            model_head = ClassifierHead(self.num_classes, embedding_dim=768)
            self.model = SetFitModel(model_body, model_head)
            self.model.freeze("body")  # Freeze embedding weights
            self.model = self.model.to(self.device)

            print(f"Initialized ReasonClassifier on {self.device}")

        except Exception as e:
            print(f"Error initializing model: {e}")
            raise RuntimeError("Failed to initialize reason classifier")

    def _map_reason_to_category(self, reason: str) -> int:
        """
        Map real healthcare reasons to categories using keyword matching.
        Based on the actual data distribution.

        Rules are checked in order, so earlier categories win on overlap
        (e.g. "ingrown toenail pain" maps to PAIN_CONDITIONS, not SKIN).
        """
        reason_lower = reason.lower()

        # ROUTINE_CARE (routine foot care, nail care, calluses)
        if any(word in reason_lower for word in ['routine', 'nail care', 'calluses']):
            return 0

        # PAIN_CONDITIONS (heel pain, ankle pain, foot pain, etc.)
        if any(word in reason_lower for word in ['pain', 'ache', 'sore']):
            return 1

        # INJURIES (ankle sprain, wounds, trauma)
        if any(word in reason_lower for word in ['sprain', 'wound', 'injury', 'trauma']):
            return 2

        # SKIN_CONDITIONS (ingrown toenail, calluses, skin issues)
        if any(word in reason_lower for word in ['ingrown', 'toenail', 'callus', 'skin']):
            return 3

        # STRUCTURAL_ISSUES (flat feet, plantar fasciitis, achilles)
        if any(word in reason_lower for word in ['flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon']):
            return 4

        # PROCEDURES (injection, surgical consult, postop)
        if any(word in reason_lower for word in ['injection', 'surgical', 'consult', 'postop', 'procedure']):
            return 5

        # Default to pain conditions (most common category)
        return 1

    def create_real_dataset(self) -> pd.DataFrame:
        """
        Create training dataset from real healthcare data.

        Each row gets: 'text' (reason, optionally suffixed with the appointment
        type), integer 'label', 'category' name, and the original fields.
        Returns a shuffled DataFrame.
        """

        training_data = []

        for _, row in self.healthcare_df.iterrows():
            reason = row['Reason For Visit']
            appointment_type = row['Appointment Type']

            # Map reason to category
            category_id = self._map_reason_to_category(reason)

            # Create enhanced text with context
            enhanced_text = reason
            if pd.notna(appointment_type):
                enhanced_text += f" | {appointment_type}"

            training_data.append({
                'text': enhanced_text,
                'label': category_id,
                'category': self.categories[category_id],
                'original_reason': reason,
                'appointment_type': appointment_type
            })

        df = pd.DataFrame(training_data)

        # Show category distribution
        print("\nCategory distribution in training data:")
        for cat_id, cat_name in self.categories.items():
            count = len(df[df['label'] == cat_id])
            percentage = (count / len(df)) * 100
            print(f"  {cat_name}: {count} samples ({percentage:.1f}%)")

        return df.sample(frac=1).reset_index(drop=True)  # Shuffle

    def train(self, train_data: pd.DataFrame = None, eval_data: Optional[pd.DataFrame] = None,
              epochs: int = 16, output_dir: str = "classifier/reason_checkpoints"):
        """Train the healthcare reason classifier.

        Args:
            train_data: labeled DataFrame with 'text'/'label'; built from the
                real dataset when None.
            eval_data: held-out DataFrame; when None a stratified 20% split of
                train_data is used.
            epochs: classifier-head epochs (contrastive phase is skipped).
            output_dir: checkpoint directory for the SetFit trainer.

        Returns:
            Final evaluation metrics dict from the trainer.
        """

        if train_data is None:
            train_data = self.create_real_dataset()

        if eval_data is None:
            train_data, eval_data = train_test_split(train_data, test_size=0.2,
                                                     stratify=train_data['label'],
                                                     random_state=42)

        train_dataset = Dataset.from_pandas(train_data)
        eval_dataset = Dataset.from_pandas(eval_data)

        from setfit import Trainer, TrainingArguments

        args = TrainingArguments(
            output_dir=output_dir,
            num_epochs=(0, epochs),  # Skip contrastive learning, only train head
            eval_strategy='epoch',
            eval_steps=100,
            save_strategy='epoch',
            logging_steps=50,
        )

        trainer = Trainer(
            model=self.model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            metric='accuracy',
            column_mapping={"text": "text", "label": "label"},
            args=args,
        )

        print("Starting training...")
        trainer.train()

        # Evaluate
        metrics = trainer.evaluate(eval_dataset)
        print(f"Training completed. Final metrics: {metrics}")

        return metrics

    def predict(self, queries: List[str]) -> List[Dict]:
        """
        Predict healthcare reason categories for a list of queries.

        Returns:
            List of dictionaries with 'query', 'category', 'confidence', 'probabilities'

        Raises:
            RuntimeError: if no model has been initialized/loaded.
        """
        if not self.model:
            raise RuntimeError("Model not initialized. Train or load a model first.")

        predictions = []

        # NOTE(review): queries are sent through the model one at a time; a
        # single batched predict/predict_proba call would likely be faster.
        for query in queries:
            # Get prediction using SetFit's built-in methods
            pred_label = self.model.predict([query])[0]
            pred_proba = self.model.predict_proba([query])[0]

            category = self.categories[int(pred_label)]
            confidence = float(pred_proba[int(pred_label)])

            predictions.append({
                'query': query,
                'category': category,
                'confidence': confidence,
                'probabilities': {self.categories[i]: float(prob)
                                  for i, prob in enumerate(pred_proba)}
            })

        return predictions

    def save_model(self, path: str):
        """Save the trained model and its category mapping under `path`."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        self.model.save_pretrained(path)

        # Save category mapping
        with open(os.path.join(path, 'categories.json'), 'w') as f:
            json.dump(self.categories, f)

        print(f"Model saved to {path}")

    def load_model(self, path: str):
        """Load a trained model and its category mapping from `path`."""
        self.model = SetFitModel.from_pretrained(path)
        self.model = self.model.to(self.device)

        # Load category mapping (JSON keys are strings; convert back to ints)
        with open(os.path.join(path, 'categories.json'), 'r') as f:
            self.categories = {int(k): v for k, v in json.load(f).items()}

        print(f"Model loaded from {path}")

    def evaluate_on_test_set(self, test_data: pd.DataFrame) -> Dict:
        """Evaluate the model on a test dataset.

        Expects 'text' and integer 'label' columns. Returns a dict with the
        sklearn classification report, confusion matrix, and accuracy.
        """
        predictions = self.predict(test_data['text'].tolist())

        y_true = test_data['label'].tolist()
        # Reverse-map predicted category names back to their integer ids.
        y_pred = [list(self.categories.keys())[list(self.categories.values()).index(p['category'])]
                  for p in predictions]

        # Classification report
        report = classification_report(y_true, y_pred,
                                       target_names=list(self.categories.values()),
                                       output_dict=True)

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)

        return {
            'classification_report': report,
            'confusion_matrix': cm.tolist(),
            'accuracy': report['accuracy']
        }

    def analyze_real_data(self):
        """Analyze the real healthcare data to understand patterns.

        Prints record counts, the top-15 visit reasons with their mapped
        categories, and the appointment-type distribution.
        """
        print("Real Data Analysis:")
        print("=" * 50)

        print(f"Total records: {len(self.healthcare_df)}")
        print(f"Unique reasons: {self.healthcare_df['Reason For Visit'].nunique()}")

        print("\nTop 15 reasons for visit:")
        top_reasons = self.healthcare_df['Reason For Visit'].value_counts().head(15)
        for reason, count in top_reasons.items():
            category_id = self._map_reason_to_category(reason)
            category_name = self.categories[category_id]
            print(f"  {reason}: {count} ({category_name})")

        print(f"\nAppointment types:")
        print(self.healthcare_df['Appointment Type'].value_counts())
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def main():
    """Example usage and training script for healthcare reason data."""
    print("Initializing Healthcare Reason Classifier...")

    # Build the classifier on top of the real clinic dataset.
    classifier = ReasonClassifier()

    # Inspect the raw data before training.
    classifier.analyze_real_data()

    print("\nCreating training dataset from real healthcare data...")
    dataset = classifier.create_real_dataset()
    print(f"Dataset created with {len(dataset)} real examples")

    print("\nTraining classifier...")
    metrics = classifier.train(dataset, epochs=20)

    # Persist the trained model for later inference.
    model_path = "classifier/reason_model"
    classifier.save_model(model_path)

    # Smoke-test the model on representative patient queries.
    test_queries = [
        "I have heel pain when I walk",
        "My toenail is ingrown and painful",
        "I need routine foot care",
        "I sprained my ankle playing sports",
        "I have flat feet and need evaluation",
        "I need a cortisone injection for my foot",
        "I have plantar fasciitis",
        "My foot wound is not healing"
    ]

    print("\nTesting predictions on healthcare reason queries:")
    for result in classifier.predict(test_queries):
        print(f"Query: {result['query']}")
        print(f"Category: {result['category']} (confidence: {result['confidence']:.3f})")
        print("---")


if __name__ == "__main__":
    main()
|
classifier/reason/train_reason.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for Healthcare Reason Classification
|
| 3 |
+
|
| 4 |
+
This script trains a classifier for healthcare visit reasons using real
|
| 5 |
+
healthcare data. It creates a separate system from the medical/insurance
|
| 6 |
+
classifier.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
from setfit import SetFitModel, Trainer, TrainingArguments
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# Make the repository root importable so `classifier.*` resolves when this
# script is executed directly rather than as a package module.
REPO_ROOT = Path(__file__).resolve().parents[2]
_repo_root = str(REPO_ROOT)
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)
|
| 18 |
+
|
| 19 |
+
from classifier.head import ClassifierHead
|
| 20 |
+
import os
|
| 21 |
+
import pandas as pd
|
| 22 |
+
from sklearn.model_selection import train_test_split
|
| 23 |
+
from datasets import Dataset
|
| 24 |
+
import torch
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
|
| 28 |
+
# Reason-specific configuration
|
| 29 |
+
REASON_CATEGORIES = {
|
| 30 |
+
0: "ROUTINE_CARE",
|
| 31 |
+
1: "PAIN_CONDITIONS",
|
| 32 |
+
2: "INJURIES",
|
| 33 |
+
3: "SKIN_CONDITIONS",
|
| 34 |
+
4: "STRUCTURAL_ISSUES",
|
| 35 |
+
5: "PROCEDURES"
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints"
|
| 39 |
+
HEALTHCARE_DATA_PATH = "data/reason_for_visit_data.xlsx"
|
| 40 |
+
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
|
| 41 |
+
|
| 42 |
+
def get_device():
    """Pick the best available torch device, preferring MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)
|
| 50 |
+
|
| 51 |
+
def get_reason_model(num_classes: int):
    """Build a SetFit model for reason classification.

    Loads the medical embedding backbone with task-specific prompts, attaches
    a trainable ClassifierHead with `num_classes` outputs, freezes the body,
    and moves the model to the best available device.

    Raises:
        RuntimeError: if the embedding model cannot be loaded.
    """
    try:
        model_body = SentenceTransformer(
            MODEL_NAME,
            prompts={
                'classification': 'task: healthcare reason classification | query: ',
                'retrieval (query)': 'task: search result | query: ',
                'retrieval (document)': 'title: {title | "none"} | text: ',
            },
            default_prompt_name='classification',
        )
        # Freeze weights of embedding model; only the head is trained.
        model_head = ClassifierHead(num_classes)
        model = SetFitModel(model_body, model_head)
        model.freeze("body")

    except Exception as e:
        print(f"Error loading model {MODEL_NAME}: {e}")
        raise RuntimeError("Failed to load the embedding model.")

    device = get_device()
    print(f"Using device: {device}")
    return model.to(device)
|
| 75 |
+
|
| 76 |
+
def get_reason_dataset() -> pd.DataFrame:
    """Load the healthcare reason dataset from the configured Excel file.

    Returns:
        DataFrame with the raw clinic records.

    Raises:
        FileNotFoundError: if HEALTHCARE_DATA_PATH does not exist.
        RuntimeError: if the file exists but cannot be parsed.
    """
    # Fix: the existence check lives outside the try block so a missing file
    # surfaces as FileNotFoundError instead of being swallowed and re-raised
    # as a bare, unchained Exception.
    if not os.path.exists(HEALTHCARE_DATA_PATH):
        raise FileNotFoundError(f"Healthcare data file not found: {HEALTHCARE_DATA_PATH}")

    print(f"Loading healthcare data from {HEALTHCARE_DATA_PATH}...")
    try:
        df = pd.read_excel(HEALTHCARE_DATA_PATH)
    except Exception as e:
        print(f"Error loading healthcare dataset: {e}")
        # RuntimeError is still an Exception subclass, so existing callers
        # catching Exception keep working; `from e` preserves the traceback.
        raise RuntimeError(f"Failed to load healthcare data: {e}") from e

    print(f"Loaded {len(df)} healthcare records")
    return df
|
| 90 |
+
|
| 91 |
+
def map_reason_to_category(reason: str) -> int:
    """Map a free-text visit reason to a healthcare category id.

    Keyword rules are applied in order, so earlier categories win on overlap
    (e.g. "ingrown toenail pain" is PAIN_CONDITIONS, not SKIN_CONDITIONS).
    Unmatched reasons default to PAIN_CONDITIONS, the most common category.
    """
    text = reason.lower()

    # (category_id, keywords) pairs, checked in priority order.
    keyword_rules = [
        (0, ('routine', 'nail care', 'calluses', 'maintenance')),      # ROUTINE_CARE
        (1, ('pain', 'ache', 'sore', 'hurt')),                         # PAIN_CONDITIONS
        (2, ('sprain', 'wound', 'injury', 'trauma', 'cut', 'bruise')), # INJURIES
        (3, ('ingrown', 'toenail', 'callus', 'corn', 'skin')),         # SKIN_CONDITIONS
        (4, ('flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon', 'arch')),  # STRUCTURAL_ISSUES
        (5, ('injection', 'surgical', 'consult', 'postop', 'surgery', 'procedure')),  # PROCEDURES
    ]

    for category_id, keywords in keyword_rules:
        if any(keyword in text for keyword in keywords):
            return category_id

    # Default to pain conditions (most common category).
    return 1
|
| 122 |
+
|
| 123 |
+
def preprocess_reason_data(df: pd.DataFrame) -> pd.DataFrame:
    """Preprocess the healthcare reason dataset for training.

    Expects a 'Reason For Visit' column (and optionally 'Appointment Type').
    Produces rows with 'text' (reason, optionally suffixed with the
    appointment type), integer 'label', 'category' name, and the original
    reason, and prints the resulting category distribution.
    """
    training_data = []

    for _, row in df.iterrows():
        reason = row['Reason For Visit']
        appointment_type = row.get('Appointment Type', '')

        # Map reason to category using keyword matching
        category_id = map_reason_to_category(reason)

        # Create enhanced text with context (skip NaN/empty appointment types)
        enhanced_text = reason
        if pd.notna(appointment_type) and appointment_type:
            enhanced_text += f" | {appointment_type}"

        training_data.append({
            'text': enhanced_text,
            'label': category_id,
            'category': REASON_CATEGORIES[category_id],
            'original_reason': reason
        })

    processed_df = pd.DataFrame(training_data)

    # Show category distribution
    print("\nReason category distribution in training data:")
    for cat_id, cat_name in REASON_CATEGORIES.items():
        count = len(processed_df[processed_df['label'] == cat_id])
        percentage = (count / len(processed_df)) * 100
        print(f"  {cat_name}: {count} samples ({percentage:.1f}%)")

    return processed_df
|
| 156 |
+
|
| 157 |
+
def main():
    """Run the end-to-end reason-classification training pipeline.

    Loads the clinic data, maps reasons to categories, trains the SetFit
    classifier head on frozen embeddings, evaluates it, and saves the head
    weights under a timestamped path.
    """
    print("Healthcare Reason Classification - Training Pipeline")
    print("=" * 60)

    # Load and preprocess data
    df = get_reason_dataset()
    df = preprocess_reason_data(df)

    # Get model (frozen body + trainable head)
    model = get_reason_model(len(REASON_CATEGORIES))

    # Split data (stratified so all categories appear in both splits)
    train, test = train_test_split(
        df, test_size=0.2, stratify=df['label'], random_state=42
    )

    print(f"\nData split:")
    print(f"  Training: {len(train)} samples")
    print(f"  Testing: {len(test)} samples")

    train_dataset = Dataset.from_pandas(train)
    test_dataset = Dataset.from_pandas(test)

    # Ensure checkpoint directory exists
    Path(REASON_CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True)

    # Training arguments; each run gets a timestamped output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"{REASON_CHECKPOINT_PATH}/training_{timestamp}"

    args = TrainingArguments(
        output_dir=output_dir,
        # Skip contrastive fine-tuning (body is frozen)
        num_epochs=(0, 20),
        eval_strategy='epoch',
        eval_steps=100,
        save_strategy='epoch',
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model='accuracy',
    )

    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        metric='accuracy',
        column_mapping={"text": "text", "label": "label"},
        args=args,
    )

    print("\nStarting reason classification training...")
    trainer.train()

    # Evaluate on the held-out split
    print("\nEvaluating reason classification model...")
    metrics = trainer.evaluate(test_dataset)
    print(f"Final evaluation metrics: {metrics}")

    # Save only the trained classifier head (the body is the frozen backbone)
    model_save_path = f"{REASON_CHECKPOINT_PATH}/reason_classifier_head_{timestamp}.pt"
    torch.save(model.model_head.state_dict(), model_save_path)
    print(f"Reason classifier head saved to: {model_save_path}")

    return metrics

if __name__ == "__main__":
    main()
|
classifier/train.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from classifier.utils import CHECKPOINT_PATH, DATETIME_FORMAT, get_models, CATEGORIES, DEVICE, CLASSIFIER_NAME
|
| 2 |
+
from classifier.config import HF_TOKEN
|
| 3 |
+
from huggingface_hub import HfApi
|
| 4 |
+
from jinja2 import Template
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import datasets as ds
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
import os
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
|
| 16 |
+
def even_split(prefix: str, target: int, splits: int, total: int) -> str:
    """Build a HF `datasets` split expression sampling rows evenly.

    Divides a split of `total` rows into `splits` equal spans and takes the
    first `target / splits` rows of each span, e.g.
    ``even_split("train", 10, 2, 100)`` -> ``"train[0:5]+train[50:55]"``.
    """
    take_per_span = int(target / splits)
    span_width = int(total / splits)

    segments = []
    for span_index in range(splits):
        start = span_width * span_index
        stop = start + take_per_span
        segments.append(f"{prefix}[{start}:{stop}]")

    return "+".join(segments)
|
| 30 |
+
|
| 31 |
+
def get_model_train_test():
    """Assemble the medical/insurance query-triage corpus and models.

    Samples the Miriad (label 'medical') and InsuranceQA (label 'insurance')
    datasets, interleaves and deduplicates each split by text, and returns
    (embedding_model, classifier, train, test, validation, CATEGORIES).
    """
    # Login using e.g. `huggingface-cli login` to access this dataset

    def add_static_label(row, column_name, label):
        # Attach a constant label column to each row (used via Dataset.map).
        row[column_name] = label
        return row

    # Miriad: sample 50k questions spread evenly over the 4.47M-row train split
    train_split = even_split("train", 50000, 100, 4470000)
    miriad = ds.load_dataset("tomaarsen/miriad-4.4M-split", split={"train":train_split, "test": "test", "validation": "eval"})
    miriad = miriad.rename_column("question", "text")
    miriad = miriad.remove_columns("passage_text")
    miriad = miriad.map(add_static_label, fn_kwargs={"column_name": "label", "label": "medical"})
    # print(miriad)

    # Insurance: sample 5k questions spread evenly over the 21.3k-row train split
    train_split = even_split("train", 5000, 20, 21300)
    insurance = ds.load_dataset("deccan-ai/insuranceQA-v2", split={"train":train_split, "test":"test", "validation":"validation"})
    insurance = insurance.rename_column("input", "text")
    insurance = insurance.remove_columns(["output"])
    insurance = insurance.map(add_static_label, fn_kwargs={"column_name": "label", "label": "insurance"})
    # print(insurance)

    # Interleave datasets (mix the datasets into one randomly)
    train = ds.interleave_datasets([miriad["train"], insurance["train"]], stopping_strategy="all_exhausted")
    # Deduplicate by text. NOTE(review): np.unique returns indices in
    # sorted-text order, so select() also reorders the rows — confirm intended.
    _ , unique_indices = np.unique(train["text"], return_index=True, axis=0)
    train = train.select(unique_indices.tolist())
    test = ds.interleave_datasets([miriad["test"], insurance["test"]], stopping_strategy="all_exhausted")
    _ , unique_indices = np.unique(test["text"], return_index=True, axis=0)
    test = test.select(unique_indices.tolist())
    validation = ds.interleave_datasets([miriad["validation"], insurance["validation"]], stopping_strategy="all_exhausted")
    _ , unique_indices = np.unique(validation["text"], return_index=True, axis=0)
    validation = validation.select(unique_indices.tolist())

    print(f"train: {len(train)}, validation: {len(validation)}, test: {len(test)}")

    # Get models
    embedding_model, classifier = get_models()

    return embedding_model, classifier, train, test, validation, CATEGORIES
|
| 71 |
+
|
| 72 |
+
def test_loop(dataloader, model, loss_fn):
    """Evaluate `model` over `dataloader`; return (average loss, accuracy).

    Batches are dicts — presumably {'sentence_embedding', 'label'} as
    produced by the collate function in this file; verify against callers.
    """
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for batch in dataloader:
            pred = model(batch)['logits']
            test_loss += loss_fn(pred, batch['label']).item()
            # Count correct predictions via argmax over the class dimension
            correct += (pred.argmax(1) == batch['label']).type(torch.float).sum().item()

    avg_loss = test_loss / num_batches
    accuracy = correct / size

    return avg_loss, accuracy
|
| 92 |
+
|
| 93 |
+
def train_loop(dataloader, model, loss_fn, optimizer, batch_size = 64, epochs = 10):
    """Run one training pass over `dataloader`; return (total loss, per-batch losses).

    NOTE(review): `epochs` is accepted but never used here — epochs appear to
    be driven by the caller; confirm before removing. `batch_size` is only
    used for the progress printout and should match the loader's batch size.
    """
    size = len(dataloader.dataset)
    total_loss = 0
    batch_losses = []

    # Set models to training mode
    model.train()

    for iteration, batch in enumerate(dataloader):
        # --- 1. Zero Gradients ---
        # Only zero gradients for the parameters you want to update (the classifier head)
        optimizer.zero_grad()

        # --- 3. Forward Pass: Embeddings -> Logits ---
        # The classifier head takes the embeddings from the body
        pred = model(batch)['logits']

        # --- 4. Calculate Loss ---
        loss = loss_fn(pred, batch['label'])

        # --- 5. Backward Pass & Update ---
        loss.backward()
        optimizer.step()

        cur_loss = loss.item()
        batch_losses.append(cur_loss)
        total_loss += cur_loss

        # Periodic progress report (every 100 batches)
        if iteration % 100 == 0:
            current = iteration * batch_size + len(batch['label'])
            print(f"loss: {cur_loss:>7f} [{current:>5d}/{size:>5d}]")

    return total_loss, batch_losses
|
| 126 |
+
|
| 127 |
+
def generate_model_card(save_dir: str, accuracy: float, loss: float, epoch: int):
    """Render the Jinja model-card template with training results.

    Reads classifier/modelcard_template.md, fills in model metadata and the
    supplied validation metrics, and writes the card to {save_dir}/README.md.
    """
    with open("classifier/modelcard_template.md", "r") as f:
        template_content = f.read()

    template = Template(template_content)

    card_content = template.render(
        model_id=CLASSIFIER_NAME,
        model_summary="A simple medical query triage classifier.",
        model_description="This model classifies queries into 'medical' or 'insurance' categories. It uses EmbeddingGemma-300M as a backbone.",
        developers="David Gray",
        model_type="Text Classification",
        language="en",
        license="mit",
        base_model="sentence-transformers/embeddinggemma-300m-medical",
        repo=f"https://huggingface.co/{CLASSIFIER_NAME}",
        # epoch is 0-indexed internally; report it 1-indexed
        results_summary=f"Epoch: {epoch+1}\nValidation Accuracy: {accuracy*100:.2f}%\nValidation Loss: {loss:.4f}",
        training_data="Miriad (medical) and InsuranceQA (insurance) datasets.",
        testing_metrics="Accuracy, Loss",
        results=f"Accuracy: {accuracy:.4f}, Loss: {loss:.4f}"
    )

    with open(f"{save_dir}/README.md", "w") as f:
        f.write(card_content)
|
| 151 |
+
|
| 152 |
+
def push_model_card(save_dir: str, repo_id: str, token: str | None = None):
    """Upload the locally generated model card to a Hugging Face model repo.

    Args:
        save_dir: Directory containing the rendered README.md.
        repo_id: Target Hub repository id (e.g. "user/model-name").
        token: Optional Hugging Face API token passed through to HfApi.
    """
    api = HfApi(token=token)
    api.upload_file(
        path_or_fileobj=f"{save_dir}/README.md",
        path_in_repo="README.md",
        repo_id=repo_id,
        repo_type="model"
    )
|
| 160 |
+
|
| 161 |
+
def label_to_int(embedding_model, label_names: list):
    """Build a DataLoader collate function that embeds texts and encodes labels.

    Despite the name, this returns a ``collate_fn`` (the label->id mapping is
    captured in its closure), not the mapping itself — the previous docstring
    incorrectly claimed a dictionary was returned.

    Args:
        embedding_model: Sentence-embedding model; ``encode`` is called with
            ``convert_to_tensor=True`` on each batch's texts.
        label_names: Ordered label strings; a label's index in this list is
            its integer class id.

    Returns:
        ``collate_fn(batch) -> dict`` with keys ``'sentence_embedding'``
        (embedding tensor on DEVICE) and ``'label'`` (long tensor on DEVICE).
    """
    label_map = {name: i for i, name in enumerate(label_names)}

    def collate_fn(batch):
        # 1. Extract texts and labels from the batch (list of dictionaries).
        texts = [item['text'] for item in batch]
        labels = [item['label'] for item in batch]

        # 2. Embed the texts. no_grad keeps the embedding model out of the
        # autograd graph; the previous redundant .clone().detach() is dropped
        # since the result is never modified in place or reused elsewhere.
        with torch.no_grad():
            embeddings = embedding_model.encode(
                texts,
                convert_to_tensor=True,
                device=DEVICE
            )

        # 3. Convert string labels to integer class ids.
        int_labels = [label_map[l] for l in labels]
        label_tensor = torch.tensor(int_labels, dtype=torch.long)

        # 4. Return embeddings and labels together, both on DEVICE.
        return {'sentence_embedding': embeddings.to(DEVICE), 'label': label_tensor.to(DEVICE)}

    return collate_fn
|
| 189 |
+
|
| 190 |
+
def train(push_to_hub: bool = False):
    """Train the triage classifier head with early stopping and checkpointing.

    Runs up to ``epochs`` epochs, validating after each one. Saves the model,
    a metrics CSV, and a batch-loss plot into a timestamped checkpoint
    directory; optionally pushes the model to the Hugging Face Hub.

    Args:
        push_to_hub: When True, upload the model to the Hub after each
            checkpoint save and after the final save.
    """
    start_datetime = datetime.now()

    # Each run gets its own timestamped checkpoint directory.
    save_dir = f'{CHECKPOINT_PATH}/{start_datetime.strftime(DATETIME_FORMAT)}'
    os.makedirs(save_dir, exist_ok=True)

    # embedding_model produces sentence embeddings; model is the trainable head.
    embedding_model, model, train_ds, test_ds, validation_ds, labels = get_model_train_test()
    batch_size = 64
    # Collate function embeds texts and maps string labels to ids (see label_to_int).
    custom_collate_fn = label_to_int(embedding_model, labels)

    train_dataloader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=custom_collate_fn
    )
    test_dataloader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=custom_collate_fn
    )
    validation_dataloader = DataLoader(
        validation_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=custom_collate_fn
    )

    loss_fn = model.get_loss_fn()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    save_per_epoch = 1   # checkpoint every N epochs
    epochs = 1
    patience = 1         # early-stopping patience (epochs without val-loss improvement)
    min_val_loss = float('inf')
    patience_counter = 0
    # Per-epoch and per-batch metric traces, written to history.csv below.
    history = {
        'train_loss_epoch': [],
        'train_loss_batch': [],
        'validation_accuracy': [],
        'validation_loss_epoch': [],
        'test_accuracy': [],
        'test_loss': []
    }

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}:\n-------------------------------")

        # Train
        total_loss, batch_losses = train_loop(train_dataloader, model, loss_fn, optimizer)
        avg_epoch_loss = total_loss / len(train_dataloader)
        history['train_loss_epoch'].append(avg_epoch_loss)
        history['train_loss_batch'].extend(batch_losses)

        summary = f"Epoch {epoch+1}:"

        # Validate
        val_loss_avg, val_accuracy = test_loop(validation_dataloader, model, loss_fn)
        history['validation_accuracy'].append(val_accuracy)
        history['validation_loss_epoch'].append(val_loss_avg)

        # NOTE(review): the first two lines report the same value twice
        # ("loss" and "training loss" are both avg_epoch_loss).
        summary += f" - loss: {avg_epoch_loss}\n"
        summary += f" - training loss: {avg_epoch_loss}\n"
        summary += f" - validation loss: {val_loss_avg:>8f}\n"
        summary += f" - validation accuracy: {(100*val_accuracy):>0.1f}%\n"

        # Save checkpoint
        if epoch % save_per_epoch == 0:
            # Save model
            model.save_pretrained(save_dir)

            # Generate and push model card (currently disabled)
            # generate_model_card(save_dir, val_accuracy, val_loss_avg, epoch)
            # push_model_card(save_dir, CLASSIFIER_NAME, token=HF_TOKEN)

            summary += f" -- {save_dir}\n"

            # Persist metrics so far; the history lists can have unequal
            # lengths, which from_dict/transpose pads out.
            history_df = pd.DataFrame.from_dict(history, orient='index').transpose()
            history_df.to_csv(f"{save_dir}/history.csv", index=False)

            # Push model to Hugging Face
            if push_to_hub:
                model.push_to_hub(CLASSIFIER_NAME, token=HF_TOKEN)
        else:
            summary += "\n"

        print(summary)

        # Early stopping: stop once validation loss fails to improve for
        # `patience` consecutive epochs.
        if val_loss_avg < min_val_loss:
            min_val_loss = val_loss_avg
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered due to no improvement in validation loss.")
                break

    # Evaluate on test dataset
    test_loss_avg, test_accuracy = test_loop(test_dataloader, model, loss_fn)
    history['test_accuracy'].append(test_accuracy)
    history['test_loss'].append(test_loss_avg)
    print(f"Test: Accuracy: {(100*test_accuracy):>0.1f}%, Avg loss: {test_loss_avg:>8f}")

    # Save the final model
    model.save_pretrained(save_dir)

    # generate_model_card(save_dir, test_accuracy, test_loss_avg, epochs-1)
    # push_model_card(save_dir, CLASSIFIER_NAME, token=HF_TOKEN)

    # Save loss history
    history_df = pd.DataFrame.from_dict(history, orient='index').transpose()
    history_df.to_csv(f"{save_dir}/history.csv", index=False)

    # Plot training loss per batch
    fig, ax = plt.subplots()
    ax.plot(history['train_loss_batch'])
    ax.set_title('Training Loss per Batch')
    ax.set_xlabel('Batch')
    ax.set_ylabel('Loss')
    fig.savefig(f"{save_dir}/loss.png")

    if push_to_hub:
        model.push_to_hub(CLASSIFIER_NAME, token=HF_TOKEN)
|
| 313 |
+
|
| 314 |
+
if __name__ == "__main__":
    # CLI entry point: train the triage classifier, optionally pushing the
    # result to the Hugging Face Hub.
    parser = argparse.ArgumentParser(
        description="Train a classifier for triaging health queries"
    )
    parser.add_argument(
        "--push", action="store_true",
        help="Push model to Hugging Face"
    )
    cli_args = parser.parse_args()

    train(push_to_hub=cli_args.push)
|
classifier/utils.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for Healthcare Classification System
|
| 3 |
+
|
| 4 |
+
This module contains shared constants and utilities for the healthcare
|
| 5 |
+
classification system.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from classifier.head import ClassifierHead
|
| 9 |
+
|
| 10 |
+
from classifier.config import load_env
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
from sentence_transformers import SentenceTransformer
|
| 14 |
+
import torch
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
# Load environment variables (including HF_TOKEN)
|
| 19 |
+
load_env()
|
| 20 |
+
|
| 21 |
+
# Hub ids for the backbone embedding model and the published classifier.
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
CLASSIFIER_NAME = "davidgray/health-query-triage"
CATEGORIES: list[str] = ["medical", "insurance"]

# Model and training configuration
# (a duplicate MODEL_NAME assignment with the identical value was removed here)
CHECKPOINT_PATH = "classifier/checkpoints"
DATETIME_FORMAT = "%Y%m%d_%H%M%S"

# Device configuration - prefer the torch.accelerator API on newer PyTorch,
# falling back to manual backend detection when it is unavailable.
# NOTE(review): the accelerator path yields a string ("cuda"/"mps"/"cpu")
# while the fallback yields torch.device objects; both are accepted by
# .to()/encode(), but callers comparing DEVICE to a string should confirm.
try:
    DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
except AttributeError:
    # Fallback for older PyTorch versions
    if torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
    elif torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cpu")

print(f"Using {DEVICE} device")
|
| 43 |
+
|
| 44 |
+
def get_models(model_id: str | None = None, num_labels: int = len(CATEGORIES)) -> tuple[SentenceTransformer, ClassifierHead]:
    """
    Loads embeddinggemma-300m-medical model and initializes the classification head.

    Args:
        model_id: Optional id/path of a pretrained classifier head to load.
            When None, a freshly initialized head with ``num_labels`` outputs
            is created instead.
        num_labels: Number of output classes for a freshly initialized head.

    Returns:
        tuple: (embedding_model, classifier_head), both moved to DEVICE.

    Raises:
        RuntimeError: if the embedding model or classifier head cannot be
            loaded (the original exception is chained as the cause).
    """
    try:
        # Prompt templates for the different tasks; classification is the
        # default since this model's primary job here is query triage.
        model_body = SentenceTransformer(
            MODEL_NAME,
            prompts={
                'classification': 'task: classification | query: ',
                'retrieval (query)': 'task: search result | query: ',
                'retrieval (document)': 'title: {title | "none"} | text: ',
            },
            default_prompt_name='classification',
        )

        if model_id:
            model_head = ClassifierHead.from_pretrained(model_id)
        else:
            model_head = ClassifierHead(num_labels)

    except Exception as e:
        print(f"Error loading model {MODEL_NAME}: {e}")
        print("Please ensure you have an internet connection and the transformers library installed.")
        # Chain the original exception so the real cause stays in the traceback.
        raise RuntimeError("Failed to load the embedding model.") from e

    return model_body.to(DEVICE), model_head.to(DEVICE)
|
| 73 |
+
|
| 74 |
+
def get_latest_checkpoint(checkpoint_path: str) -> str:
    """Return the path of the lexicographically greatest entry in *checkpoint_path*.

    Checkpoint directories are named with DATETIME_FORMAT ("%Y%m%d_%H%M%S"),
    so lexicographic order matches chronological order. Uses max() instead of
    sorting the whole listing (O(n) vs O(n log n)).

    Raises:
        ValueError: if the checkpoint directory contains no entries.
    """
    entries = os.listdir(checkpoint_path)
    if not entries:
        raise ValueError(f"No checkpoints found in {checkpoint_path}")
    return os.path.join(checkpoint_path, max(entries))
|
cli/healthcare_classifier_cli.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
End-to-End Healthcare Classification CLI
|
| 3 |
+
|
| 4 |
+
This provides a complete classification pipeline:
|
| 5 |
+
1. First classifies as "medical" or "insurance"
|
| 6 |
+
2. If medical, applies reason classification for detailed categorization
|
| 7 |
+
|
| 8 |
+
IMPORTANT: Activate virtual environment first!
|
| 9 |
+
Usage:
|
| 10 |
+
source .venv/bin/activate
|
| 11 |
+
python cli/healthcare_classifier_cli.py --interactive
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
# Add project root to path
|
| 20 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 21 |
+
if str(REPO_ROOT) not in sys.path:
|
| 22 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 23 |
+
|
| 24 |
+
def classify_healthcare_query(query: str):
    """
    Complete healthcare query classification pipeline.

    Step 1: Medical vs Insurance classification
    Step 2: If medical, apply reason classification

    Args:
        query: Free-text healthcare query from a user.

    Returns:
        dict with the primary/reason classifications, their confidences, and
        a routing string — or None when classification failed entirely.
    """

    print(f"Query: {query}")
    print("=" * 60)

    try:
        # Add classifier to path
        sys.path.append('classifier')

        # Step 1: Medical vs Insurance Classification
        print("🔍 Step 1: Medical vs Insurance Classification")
        print("-" * 40)

        # Imported lazily so the CLI can start (and print usage help) even
        # when torch & friends are missing; see the except branch below.
        from infer import predict_query
        from utils import get_models

        # Load medical/insurance classifier
        embedding_model, classifier_head = get_models()

        # Get medical vs insurance prediction
        result = predict_query([query], embedding_model, classifier_head)

        primary_category = result['prediction'][0]
        confidence = result['confidence']
        # predict_query may return per-query lists; unwrap the single entry.
        if isinstance(confidence, list):
            confidence = confidence[0]

        print(f"Primary Classification: {primary_category.upper()}")
        print(f"Confidence: {confidence:.4f}")

        # Show probabilities (same unwrap as for confidence).
        probabilities = result['probabilities']
        if isinstance(probabilities[0], list):
            probabilities = probabilities[0]

        print("Probabilities:")
        from utils import CATEGORIES
        for i, category in enumerate(CATEGORIES):
            print(f" {category}: {probabilities[i]:.4f}")

        # Step 2: If medical, apply reason classification
        if primary_category.lower() == 'medical':
            print(f"\n🏥 Step 2: Medical Reason Classification")
            print("-" * 40)

            try:
                from classifier.reason.infer_reason import predict_single_reason

                reason_result = predict_single_reason(query)

                print(f"Medical Reason: {reason_result['category']}")
                print(f"Reason Confidence: {reason_result['confidence']:.4f}")

                print("Reason Probabilities:")
                # Rank reason categories by descending probability.
                sorted_probs = sorted(reason_result['probabilities'].items(),
                                      key=lambda x: x[1], reverse=True)
                for category, prob in sorted_probs:
                    print(f" {category}: {prob:.4f}")

                # Final routing decision
                print(f"\n🎯 Final Routing Decision")
                print("-" * 25)
                print(f"Route to: {reason_result['category']} Department")
                print(f"Overall confidence: Medical ({confidence:.3f}) → {reason_result['category']} ({reason_result['confidence']:.3f})")

                return {
                    'primary_classification': primary_category,
                    'primary_confidence': confidence,
                    'reason_classification': reason_result['category'],
                    'reason_confidence': reason_result['confidence'],
                    'routing': f"{reason_result['category']} Department"
                }

            except Exception as e:
                # Reason model unavailable/untrained: degrade to a generic route.
                print(f"⚠️ Reason classification failed: {e}")
                print("Note: Make sure reason classifier is trained")
                print(f"Routing to: General Medical Department")

                return {
                    'primary_classification': primary_category,
                    'primary_confidence': confidence,
                    'reason_classification': 'GENERAL_MEDICAL',
                    'reason_confidence': 0.0,
                    'routing': 'General Medical Department'
                }

        else:
            # Insurance query
            print(f"\n💳 Final Routing Decision")
            print("-" * 25)
            print(f"Route to: Insurance Department")
            print(f"Confidence: {confidence:.3f}")

            return {
                'primary_classification': primary_category,
                'primary_confidence': confidence,
                'reason_classification': None,
                'reason_confidence': None,
                'routing': 'Insurance Department'
            }

    except Exception as e:
        print(f"❌ Classification failed: {e}")
        # Most common failure: running outside the project's virtualenv.
        if "No module named 'torch'" in str(e):
            print("\n🔧 SOLUTION:")
            print("You need to activate the virtual environment first!")
            print("Run these commands:")
            print(" source .venv/bin/activate")
            print(" python cli/healthcare_classifier_cli.py --interactive")
        else:
            print("Note: Make sure models are trained and available")
        return None
|
| 142 |
+
|
| 143 |
+
def classify_batch_queries(queries_file: str, output_file: str | None = None):
    """Process multiple queries through the complete pipeline.

    Reads queries from *queries_file* — a JSON list, a JSON object with a
    "queries" key, or a plain text file with one query per line — classifies
    each, optionally writes a JSON report to *output_file*, and prints a
    summary.

    Returns:
        True on success, False when reading or processing failed.
    """

    try:
        # Read queries
        with open(queries_file, 'r') as f:
            if queries_file.endswith('.json'):
                data = json.load(f)
                if isinstance(data, list):
                    queries = data
                else:
                    queries = data.get('queries', [])
            else:
                queries = [line.strip() for line in f if line.strip()]

        print(f"Processing {len(queries)} queries through complete pipeline...")
        print("=" * 60)

        results = []
        for i, query in enumerate(queries, 1):
            print(f"\n📋 Query {i}/{len(queries)}")
            result = classify_healthcare_query(query)
            # Failed classifications return None and are excluded from results.
            if result:
                result['query'] = query
                results.append(result)
            print()

        # Save results if output file specified
        if output_file:
            output_data = {
                'queries': queries,
                'predictions': results,
                'summary': {
                    'total_queries': len(queries),
                    'medical_queries': len([r for r in results if r['primary_classification'].lower() == 'medical']),
                    'insurance_queries': len([r for r in results if r['primary_classification'].lower() == 'insurance']),
                    'reason_categories': {}
                }
            }

            # Count reason categories
            for result in results:
                if result['reason_classification']:
                    cat = result['reason_classification']
                    output_data['summary']['reason_categories'][cat] = output_data['summary']['reason_categories'].get(cat, 0) + 1

            with open(output_file, 'w') as f:
                json.dump(output_data, f, indent=2)

            print(f"📄 Results saved to {output_file}")

        # Show summary
        medical_count = len([r for r in results if r['primary_classification'].lower() == 'medical'])
        insurance_count = len([r for r in results if r['primary_classification'].lower() == 'insurance'])

        # FIX: guard against an empty results list (e.g. every query failed),
        # which previously raised ZeroDivisionError in the percentage math and
        # made the whole batch report a generic failure.
        if results:
            print(f"\n📊 Summary:")
            print(f" Medical queries: {medical_count} ({medical_count/len(results)*100:.1f}%)")
            print(f" Insurance queries: {insurance_count} ({insurance_count/len(results)*100:.1f}%)")
        else:
            print(f"\n📊 Summary: no queries were successfully classified")

        if medical_count > 0:
            reason_counts = {}
            for result in results:
                if result['reason_classification']:
                    cat = result['reason_classification']
                    reason_counts[cat] = reason_counts.get(cat, 0) + 1

            print(f"\n Medical reason breakdown:")
            for category, count in sorted(reason_counts.items()):
                percentage = (count / medical_count) * 100
                print(f" {category}: {count} queries ({percentage:.1f}%)")

    except Exception as e:
        print(f"❌ Error processing batch queries: {e}")
        return False

    return True
|
| 219 |
+
|
| 220 |
+
def interactive_mode():
    """Interactive REPL for the end-to-end healthcare classification pipeline.

    Prints a banner with usage examples, then reads queries until the user
    types 'quit' or interrupts with Ctrl-C.
    """

    banner = (
        "🏥 Complete Healthcare Classification System",
        "=" * 50,
        "This system provides end-to-end classification:",
        " 1️⃣ Medical vs Insurance classification",
        " 2️⃣ Medical reason classification (if medical)",
        " 3️⃣ Final routing decision",
        "",
        "Enter healthcare queries to classify (type 'quit' to exit)",
        "",
        "Example queries to try:",
        " Medical: 'I have heel pain when I walk'",
        " Medical: 'I need routine foot care'",
        " Medical: 'I sprained my ankle'",
        " Insurance: 'My insurance claim was denied'",
        " Insurance: 'What does my insurance cover?'",
        "",
    )
    for banner_line in banner:
        print(banner_line)

    while True:
        try:
            user_input = input("🔍 Enter query >>> ").strip()

            if user_input.lower() == 'quit':
                print("👋 Goodbye!")
                break

            if user_input:
                classify_healthcare_query(user_input)
                print("\n" + "=" * 60)

        except KeyboardInterrupt:
            print("\n👋 Goodbye!")
            break
        except Exception as e:
            print(f"❌ Error: {e}")
            print()
|
| 258 |
+
|
| 259 |
+
def main():
    """Parse CLI arguments and dispatch to interactive, batch, or single-query mode.

    Returns:
        Process exit code: 0 on success, 1 on failure or usage error.
    """
    parser = argparse.ArgumentParser(
        description='Complete Healthcare Classification CLI',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Interactive mode (recommended)
  python cli/healthcare_classifier_cli.py --interactive

  # Classify a single query
  python cli/healthcare_classifier_cli.py "I have heel pain"

  # Batch process queries from file
  python cli/healthcare_classifier_cli.py --batch queries.txt --output results.json

Pipeline:
  Query → Medical/Insurance → (if Medical) → Reason Classification → Routing
"""
    )

    parser.add_argument('query', nargs='?', help='Healthcare query to classify')
    parser.add_argument('--batch', type=str, help='File containing queries to process')
    parser.add_argument('--output', type=str, help='Output file for batch results')
    parser.add_argument('--interactive', action='store_true',
                        help='Start interactive mode (recommended)')

    args = parser.parse_args()

    # Interactive mode takes precedence over any positional query.
    if args.interactive:
        interactive_mode()
        return 0

    # Batch processing
    if args.batch:
        if not Path(args.batch).exists():
            print(f"❌ Error: Batch file does not exist: {args.batch}")
            return 1

        success = classify_batch_queries(args.batch, args.output)
        return 0 if success else 1

    # Single query processing
    if args.query:
        result = classify_healthcare_query(args.query)
        return 0 if result else 1

    # No arguments provided - show help and suggest interactive mode
    print("🏥 Complete Healthcare Classification System")
    print("=" * 45)
    print("IMPORTANT: Activate virtual environment first!")
    print(" source .venv/bin/activate")
    print(" python cli/healthcare_classifier_cli.py --interactive")
    print()
    parser.print_help()
    return 1

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|
cli/reason_classifier_cli.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CLI Interface for Healthcare Reason Classification
|
| 3 |
+
|
| 4 |
+
This provides a command-line interface for testing and using the
|
| 5 |
+
healthcare reason classifier system with real healthcare data.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Add project root to path
|
| 14 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 15 |
+
if str(REPO_ROOT) not in sys.path:
|
| 16 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 17 |
+
|
| 18 |
+
def classify_single_query(query: str):
    """Run the reason classifier on one query and print a ranked report.

    Returns:
        True on success, False if inference failed (e.g. the reason
        classifier has not been trained yet).
    """

    print(f"Query: {query}")
    print("-" * 50)

    try:
        # Reason inference lives under the classifier package.
        sys.path.append('classifier')
        from classifier.reason.infer_reason import predict_single_reason

        prediction = predict_single_reason(query)

        print(f"Primary Classification: {prediction['category']}")
        print(f"Confidence: {prediction['confidence']:.4f}")

        print(f"\nAll Category Probabilities:")

        # Report every category, highest probability first.
        ranked = sorted(prediction['probabilities'].items(),
                        key=lambda kv: kv[1], reverse=True)
        for label, probability in ranked:
            print(f" {label}: {probability:.4f}")

    except Exception as err:
        print(f"Error: {err}")
        print("Note: Make sure the reason classifier is trained")
        return False

    return True
|
| 51 |
+
|
| 52 |
+
def classify_batch_queries(queries_file: str, output_file: str | None = None):
    """Classify multiple queries from a file.

    Accepts a JSON list, a JSON object with a "queries" key, or a plain
    text file with one query per line. Optionally writes a JSON report with
    per-query predictions and category counts, then prints a summary.

    Returns:
        True on success, False if reading or inference failed.
    """

    try:
        # Read queries
        with open(queries_file, 'r') as f:
            if queries_file.endswith('.json'):
                data = json.load(f)
                if isinstance(data, list):
                    queries = data
                else:
                    queries = data.get('queries', [])
            else:
                queries = [line.strip() for line in f if line.strip()]

        print(f"Processing {len(queries)} healthcare reason queries...")

        # Import the reason inference module (lazy so --help works without deps)
        sys.path.append('classifier')
        from classifier.reason.infer_reason import predict_single_reason

        results = []
        for i, query in enumerate(queries, 1):
            print(f"\n{i}. Query: {query}")

            result = predict_single_reason(query)
            results.append(result)

            print(f" Category: {result['category']} (confidence: {result['confidence']:.3f})")

        # Save results if output file specified
        if output_file:
            output_data = {
                'queries': queries,
                'predictions': results,
                'summary': {
                    'total_queries': len(queries),
                    'categories': {}
                }
            }

            # Count categories
            for result in results:
                cat = result['category']
                output_data['summary']['categories'][cat] = output_data['summary']['categories'].get(cat, 0) + 1

            with open(output_file, 'w') as f:
                json.dump(output_data, f, indent=2)

            print(f"\nResults saved to {output_file}")

        # Show summary of how often each category was predicted.
        category_counts = {}
        for result in results:
            cat = result['category']
            category_counts[cat] = category_counts.get(cat, 0) + 1

        print(f"\nSummary:")
        for category, count in sorted(category_counts.items()):
            percentage = (count / len(queries)) * 100
            print(f" {category}: {count} queries ({percentage:.1f}%)")

    except Exception as e:
        print(f"Error processing batch queries: {e}")
        return False

    return True
|
| 119 |
+
|
| 120 |
+
def interactive_mode():
    """Interactive prompt loop for classifying healthcare reason queries.

    Reads queries until the user types 'quit' or presses Ctrl-C.
    """

    intro = (
        "Healthcare Reason Classifier - Interactive Mode",
        "=" * 50,
        "Enter healthcare reason queries to classify (type 'quit' to exit)",
        "",
        "Example queries to try:",
        " • 'I have heel pain when I walk'",
        " • 'My toenail is ingrown and infected'",
        " • 'I need routine foot care'",
        " • 'I sprained my ankle playing basketball'",
        " • 'I have plantar fasciitis'",
        " • 'I need a cortisone injection'",
        "",
    )
    for intro_line in intro:
        print(intro_line)

    while True:
        try:
            text = input(">>> ").strip()

            if text.lower() == 'quit':
                break

            if text:
                classify_single_query(text)
                print()

        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"Error: {e}")
            print()
|
| 152 |
+
|
| 153 |
+
def main():
    """CLI entry point: dispatch to interactive, batch, or single-query mode."""
    parser = argparse.ArgumentParser(
        description='Healthcare Reason Classification CLI',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Classify a single healthcare reason query
  python cli/reason_classifier_cli_new.py "I have heel pain"

  # Batch process queries from file
  python cli/reason_classifier_cli_new.py --batch reason_queries.txt --output results.json

  # Interactive mode
  python cli/reason_classifier_cli_new.py --interactive
    """
    )

    parser.add_argument('query', nargs='?', help='Healthcare reason query to classify')
    parser.add_argument('--batch', type=str, help='File containing queries to process')
    parser.add_argument('--output', type=str, help='Output file for batch results')
    parser.add_argument('--interactive', action='store_true',
                        help='Start interactive mode')

    args = parser.parse_args()

    # Interactive mode takes precedence over the other modes.
    if args.interactive:
        interactive_mode()
        return 0

    # Batch processing: validate the input file up front.
    if args.batch:
        if not Path(args.batch).exists():
            print(f"Error: Batch file does not exist: {args.batch}")
            return 1
        return 0 if classify_batch_queries(args.batch, args.output) else 1

    # Single query processing.
    if args.query:
        return 0 if classify_single_query(args.query) else 1

    # No arguments provided: show usage and signal failure.
    parser.print_help()
    return 1

if __name__ == "__main__":
    sys.exit(main())
|
config.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
from pathlib import Path
from typing import Dict, List
import torch
from pydantic_settings import BaseSettings

# NOTE(review): `os` and `Path` are imported but unused in this module — confirm
# before removing, since other modules may rely on re-exporting them from here.

class Settings(BaseSettings):
    """Central application configuration.

    Values can be overridden via environment variables or a `.env` file
    (see the inner `Config` class below).
    """

    # Model Configuration
    MODEL_NAME: str = "sentence-transformers/embeddinggemma-300m-medical"  # embedding model id
    CLASSIFIER_NAME: str = "davidgray/health-query-triage"  # HF repo of the triage classifier head
    CATEGORIES: List[str] = ["medical", "insurance"]  # classifier label order (index-aligned with probabilities)

    # Paths
    CHECKPOINT_PATH: str = "classifier/checkpoints"
    CACHE_DIR: str = ".cache/embeddings"

    # Device: prefer CUDA, then Apple MPS, else CPU. Resolved once at import time.
    DEVICE: str = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")

    # Corpora Configuration: corpus name -> file path + fields to concatenate as text.
    CORPORA_CONFIG: Dict[str, dict] = {
        "medical_qa": {"path": "data/corpora/medical_qa.jsonl",
                       "text_fields": ["question", "answer", "title"]},
        "miriad": {"path": "data/corpora/miriad_text.jsonl",
                   "text_fields": ["question", "answer", "title"]},
        "pubmed": {"path": "data/corpora/pubmed.json",
                   "text_fields": ["contents","title"]},
        "unidoc": {"path": "data/corpora/unidoc_qa.jsonl",
                   "text_fields": ["question", "answer", "title"]},
    }

    # NOTE(review): pydantic v2 prefers `model_config = SettingsConfigDict(env_file=".env")`;
    # the inner Config class still works in pydantic-settings 2.x but is deprecated.
    class Config:
        env_file = ".env"

# Module-level singleton used throughout the project.
settings = Settings()
|
environment.yml
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: cs410-group-proj
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
- conda-forge
|
| 5 |
+
- pytorch
|
| 6 |
+
- huggingface
|
| 7 |
+
- anaconda
|
| 8 |
+
dependencies:
|
| 9 |
+
- blas=1.0=openblas
|
| 10 |
+
- brotli-python=1.0.9=py311h313beb8_9
|
| 11 |
+
- bzip2=1.0.8=hd037594_8
|
| 12 |
+
- ca-certificates=2025.10.5=hbd8a1cb_0
|
| 13 |
+
- certifi=2025.10.5=py311hca03da5_0
|
| 14 |
+
- cffi=2.0.0=py311h3a083c1_0
|
| 15 |
+
- cryptography=44.0.1=py311h8026fc7_0
|
| 16 |
+
- faiss-cpu=1.12.0=py3.11_hcb8d3e5_0_cpu
|
| 17 |
+
- gmp=6.3.0=h313beb8_0
|
| 18 |
+
- gmpy2=2.2.1=py311h5c1b81f_0
|
| 19 |
+
- huggingface_hub=0.29.2=py_0
|
| 20 |
+
- icu=75.1=hfee45f7_0
|
| 21 |
+
- jinja2=3.1.6=py311hca03da5_0
|
| 22 |
+
- libcxx=20.1.8=h8869778_0
|
| 23 |
+
- libexpat=2.7.1=hec049ff_0
|
| 24 |
+
- libfaiss=1.12.0=py3.11_hcb8d3e5_0_cpu
|
| 25 |
+
- libffi=3.4.6=h1da3d7d_1
|
| 26 |
+
- libgfortran=5.0.0=11_3_0_hca03da5_28
|
| 27 |
+
- libgfortran5=11.3.0=h009349e_28
|
| 28 |
+
- libidn2=2.3.4=h80987f9_0
|
| 29 |
+
- liblzma=5.8.1=h39f12f2_2
|
| 30 |
+
- libopenblas=0.3.30=hf2bb037_0
|
| 31 |
+
- libsqlite=3.50.4=h4237e3c_0
|
| 32 |
+
- libunistring=0.9.10=h1a28f6b_0
|
| 33 |
+
- libzlib=1.3.1=h5f15de7_0
|
| 34 |
+
- llvm-openmp=20.1.8=he822017_0
|
| 35 |
+
- maven=3.9.11=hce30654_0
|
| 36 |
+
- mpc=1.3.1=h80987f9_0
|
| 37 |
+
- mpfr=4.2.1=h80987f9_0
|
| 38 |
+
- mpmath=1.3.0=py311hca03da5_0
|
| 39 |
+
- ncurses=6.5=h5e97a16_3
|
| 40 |
+
- networkx=3.5=py311hca03da5_0
|
| 41 |
+
- nomkl=3.0=0
|
| 42 |
+
- numpy=1.26.4=py311h901140f_1
|
| 43 |
+
- numpy-base=1.26.4=py311hae06d03_1
|
| 44 |
+
- openblas=0.3.30=hb03180a_0
|
| 45 |
+
- openblas-devel=0.3.30=h1465027_0
|
| 46 |
+
- openjdk=21.0.8=h55d13f6_0
|
| 47 |
+
- openssl=3.5.4=h5503f6c_0
|
| 48 |
+
- packaging=25.0=py311hca03da5_0
|
| 49 |
+
- pip=25.2=pyh8b19718_0
|
| 50 |
+
- pycparser=2.23=py311hca03da5_0
|
| 51 |
+
- pyopenssl=25.0.0=py311h9e2d7d8_0
|
| 52 |
+
- pysocks=1.7.1=py311hca03da5_0
|
| 53 |
+
- python=3.11.14=hec0b533_1_cpython
|
| 54 |
+
- python_abi=3.11=1_cp311
|
| 55 |
+
- pytorch=2.2.2=py3.11_0
|
| 56 |
+
- readline=8.2=h1d1bf99_2
|
| 57 |
+
- requests=2.32.5=py311hca03da5_0
|
| 58 |
+
- setuptools=80.9.0=pyhff2d567_0
|
| 59 |
+
- sympy=1.14.0=py311hca03da5_0
|
| 60 |
+
- tk=8.6.13=h892fb3f_2
|
| 61 |
+
- tqdm=4.67.1=py311hb6e6a13_0
|
| 62 |
+
- typing-extensions=4.15.0=py311hca03da5_0
|
| 63 |
+
- typing_extensions=4.15.0=py311hca03da5_0
|
| 64 |
+
- wget=1.24.5=h3e2b118_0
|
| 65 |
+
- wheel=0.45.1=pyhd8ed1ab_1
|
| 66 |
+
- yaml=0.2.5=h1a28f6b_0
|
| 67 |
+
- zlib=1.3.1=h5f15de7_0
|
| 68 |
+
- pip:
|
| 69 |
+
- aiohappyeyeballs==2.6.1
|
| 70 |
+
- aiohttp==3.13.1
|
| 71 |
+
- aiosignal==1.4.0
|
| 72 |
+
- annotated-types==0.7.0
|
| 73 |
+
- anyio==4.11.0
|
| 74 |
+
- attrs==25.4.0
|
| 75 |
+
- blinker==1.9.0
|
| 76 |
+
- blis==1.3.0
|
| 77 |
+
- catalogue==2.0.10
|
| 78 |
+
- charset-normalizer==3.4.4
|
| 79 |
+
- click==8.3.0
|
| 80 |
+
- cloudpathlib==0.23.0
|
| 81 |
+
- coloredlogs==15.0.1
|
| 82 |
+
- confection==0.1.5
|
| 83 |
+
- cymem==2.0.11
|
| 84 |
+
- cython==3.1.4
|
| 85 |
+
- datasets==2.13.2
|
| 86 |
+
- dill==0.3.6
|
| 87 |
+
- distro==1.9.0
|
| 88 |
+
- fastapi==0.119.0
|
| 89 |
+
- filelock==3.20.0
|
| 90 |
+
- flask==3.1.2
|
| 91 |
+
- flatbuffers==25.9.23
|
| 92 |
+
- frozenlist==1.8.0
|
| 93 |
+
- fsspec==2025.9.0
|
| 94 |
+
- h11==0.16.0
|
| 95 |
+
- hf-xet==1.1.10
|
| 96 |
+
- httpcore==1.0.9
|
| 97 |
+
- httpx==0.28.1
|
| 98 |
+
- httpx-sse==0.4.3
|
| 99 |
+
- huggingface-hub==0.35.3
|
| 100 |
+
- humanfriendly==10.0
|
| 101 |
+
- idna==3.11
|
| 102 |
+
- itsdangerous==2.2.0
|
| 103 |
+
- jiter==0.11.1
|
| 104 |
+
- joblib==1.5.2
|
| 105 |
+
- jsonschema==4.25.1
|
| 106 |
+
- jsonschema-specifications==2025.9.1
|
| 107 |
+
- langcodes==3.5.0
|
| 108 |
+
- language-data==1.3.0
|
| 109 |
+
- marisa-trie==1.3.1
|
| 110 |
+
- markdown-it-py==4.0.0
|
| 111 |
+
- markupsafe==3.0.3
|
| 112 |
+
- mcp==1.18.0
|
| 113 |
+
- mdurl==0.1.2
|
| 114 |
+
- multidict==6.7.0
|
| 115 |
+
- multiprocess==0.70.14
|
| 116 |
+
- murmurhash==1.0.13
|
| 117 |
+
- onnxruntime==1.23.1
|
| 118 |
+
- openai==2.5.0
|
| 119 |
+
- pandas==2.3.3
|
| 120 |
+
- pillow==12.0.0
|
| 121 |
+
- preshed==3.0.10
|
| 122 |
+
- propcache==0.4.1
|
| 123 |
+
- protobuf==6.33.0
|
| 124 |
+
- pyarrow==11.0.0
|
| 125 |
+
- pybind11==3.0.1
|
| 126 |
+
- pydantic==2.12.3
|
| 127 |
+
- pydantic-core==2.41.4
|
| 128 |
+
- pydantic-settings==2.11.0
|
| 129 |
+
- pygments==2.19.2
|
| 130 |
+
- pyjnius==1.7.0
|
| 131 |
+
- pyserini==1.2.0
|
| 132 |
+
- python-dateutil==2.9.0.post0
|
| 133 |
+
- python-dotenv==1.1.1
|
| 134 |
+
- python-multipart==0.0.20
|
| 135 |
+
- pytz==2025.2
|
| 136 |
+
- pyyaml==6.0.3
|
| 137 |
+
- referencing==0.37.0
|
| 138 |
+
- regex==2025.9.18
|
| 139 |
+
- rich==14.2.0
|
| 140 |
+
- rpds-py==0.27.1
|
| 141 |
+
- safetensors==0.6.2
|
| 142 |
+
- scikit-learn==1.7.2
|
| 143 |
+
- scipy==1.16.2
|
| 144 |
+
- sentencepiece==0.2.1
|
| 145 |
+
- shellingham==1.5.4
|
| 146 |
+
- six==1.17.0
|
| 147 |
+
- smart-open==7.3.1
|
| 148 |
+
- sniffio==1.3.1
|
| 149 |
+
- spacy==3.8.7
|
| 150 |
+
- spacy-legacy==3.0.12
|
| 151 |
+
- spacy-loggers==1.0.5
|
| 152 |
+
- srsly==2.5.1
|
| 153 |
+
- sse-starlette==3.0.2
|
| 154 |
+
- starlette==0.48.0
|
| 155 |
+
- thinc==8.3.6
|
| 156 |
+
- threadpoolctl==3.6.0
|
| 157 |
+
- tiktoken==0.12.0
|
| 158 |
+
- tokenizers==0.22.1
|
| 159 |
+
- torch==2.9.0
|
| 160 |
+
- torchaudio==2.9.0
|
| 161 |
+
- torchvision==0.24.0
|
| 162 |
+
- transformers==4.57.1
|
| 163 |
+
- typer==0.19.2
|
| 164 |
+
- typing-inspection==0.4.2
|
| 165 |
+
- tzdata==2025.2
|
| 166 |
+
- urllib3==2.5.0
|
| 167 |
+
- uvicorn==0.38.0
|
| 168 |
+
- wasabi==1.1.3
|
| 169 |
+
- weasel==0.4.1
|
| 170 |
+
- werkzeug==3.1.3
|
| 171 |
+
- wrapt==2.0.0
|
| 172 |
+
- xxhash==3.6.0
|
| 173 |
+
- yarl==1.22.0
|
| 174 |
+
prefix: /opt/homebrew/Caskroom/miniconda/base/envs/cs410-group-proj
|
launch_ui.bat
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
REM Medical Q&A Bot - Easy Launcher
|
| 3 |
+
REM Double-click this file to launch the web UI
|
| 4 |
+
|
| 5 |
+
echo ========================================
|
| 6 |
+
echo Medical Q&A Bot - Web Interface
|
| 7 |
+
echo ========================================
|
| 8 |
+
echo.
|
| 9 |
+
|
| 10 |
+
REM Check if virtual environment exists
|
| 11 |
+
if exist ".venv\Scripts\activate.bat" (
|
| 12 |
+
echo Activating virtual environment...
|
| 13 |
+
call .venv\Scripts\activate.bat
|
| 14 |
+
) else (
|
| 15 |
+
echo No virtual environment found. Using system Python.
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
echo.
|
| 19 |
+
echo Launching Gradio interface...
|
| 20 |
+
echo The web UI will open at: http://127.0.0.1:7860
|
| 21 |
+
echo.
|
| 22 |
+
echo Press Ctrl+C to stop the server
|
| 23 |
+
echo ========================================
|
| 24 |
+
echo.
|
| 25 |
+
|
| 26 |
+
python app.py
|
| 27 |
+
|
| 28 |
+
pause
|
launch_ui.ps1
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Medical Q&A Bot - PowerShell Launcher
|
| 2 |
+
# Run this script to launch the web UI
|
| 3 |
+
|
| 4 |
+
Write-Host "========================================" -ForegroundColor Cyan
|
| 5 |
+
Write-Host "Medical Q&A Bot - Web Interface" -ForegroundColor Cyan
|
| 6 |
+
Write-Host "========================================" -ForegroundColor Cyan
|
| 7 |
+
Write-Host ""
|
| 8 |
+
|
| 9 |
+
# Check if virtual environment exists
|
| 10 |
+
if (Test-Path ".venv\Scripts\Activate.ps1") {
|
| 11 |
+
Write-Host "Activating virtual environment..." -ForegroundColor Green
|
| 12 |
+
& .venv\Scripts\Activate.ps1
|
| 13 |
+
} else {
|
| 14 |
+
Write-Host "No virtual environment found. Using system Python." -ForegroundColor Yellow
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
Write-Host ""
|
| 18 |
+
Write-Host "Launching Gradio interface..." -ForegroundColor Green
|
| 19 |
+
Write-Host "The web UI will open at: http://127.0.0.1:7860" -ForegroundColor Yellow
|
| 20 |
+
Write-Host ""
|
| 21 |
+
Write-Host "Press Ctrl+C to stop the server" -ForegroundColor Red
|
| 22 |
+
Write-Host "========================================" -ForegroundColor Cyan
|
| 23 |
+
Write-Host ""
|
| 24 |
+
|
| 25 |
+
# Launch the app
|
| 26 |
+
python app.py
|
| 27 |
+
|
| 28 |
+
# Keep window open on error
|
| 29 |
+
if ($LASTEXITCODE -ne 0) {
|
| 30 |
+
Write-Host ""
|
| 31 |
+
Write-Host "An error occurred. Press any key to exit..." -ForegroundColor Red
|
| 32 |
+
$null = $Host.UI.RawUI.ReadKey("NoEcho,IncludeKeyDown")
|
| 33 |
+
}
|
main.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from dataclasses import asdict
|
| 4 |
+
|
| 5 |
+
from pipeline import HealthQueryPipeline
|
| 6 |
+
|
| 7 |
+
EXIT_COMMANDS = ["exit", "quit"]
|
| 8 |
+
PROMPT = "\nQuery> "
|
| 9 |
+
|
| 10 |
+
def main(pipeline: HealthQueryPipeline, k: int) -> None:
    """Interactive REPL: classify each query and show retrieval hits for medical ones.

    Args:
        pipeline: Initialized HealthQueryPipeline used for classification and retrieval.
        k: Number of retrieval results to request per query.

    Exits on empty input, an EXIT_COMMANDS word, Ctrl-D, or Ctrl-C.
    """
    print("(Ctrl-D or 'quit' to exit)")

    while True:
        try:
            query = input(PROMPT).strip()
            if not query or query.lower() in EXIT_COMMANDS:
                break

            # Show index status while the background index build is incomplete.
            curr, total = pipeline.get_index_progress()
            if total > 0:
                pct = int((curr / total) * 100)
                if pct < 100:
                    print(f"[Index: {pct}% loaded]")

            # Use the pipeline to get results
            result = pipeline.predict(query, k=k)

            classification = result["classification"]
            prediction = classification["prediction"]

            print(f"\nTriaging query as {prediction}")
            print("\nConfidence:")
            for cat, prob in classification["probabilities"].items():
                print(f"  {cat}: {prob * 100:3.2f}%")
            print()

            if prediction == "medical":
                hits = result["retrieval"]
                print(f"Found {len(hits)} matching medical documents\n")

                if not hits:
                    print("No medical documents found.\n")
                    continue

                for hit in hits:
                    # hit is already a dict from the pipeline
                    print(json.dumps(hit, indent=2, ensure_ascii=False))
            else:
                print(f"TODO: handle queries of type {prediction}")
                continue

        except (EOFError, KeyboardInterrupt):
            # Both Ctrl-D and Ctrl-C end the session the same way.
            print("\nBye!")
            break


if __name__ == "__main__":
    ap = argparse.ArgumentParser(
        description="Hybrid retrieval (BM25 + Dense + RRF, optional re-rank)"
    )
    ap.add_argument("--k", type=int, default=10, help="Number of results to return")
    ap.add_argument(
        "--rerank", action="store_true",
        help="Use cross-encoder reranker (slower, usually better)"
    )
    args = ap.parse_args()

    # Initialize pipeline (loads models and builds/loads the index).
    pipeline = HealthQueryPipeline(use_reranker=args.rerank)
    pipeline.initialize()

    main(pipeline, k=args.k)
|
pipeline.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from dataclasses import asdict
|
| 3 |
+
from typing import List, Dict, Any, Optional
|
| 4 |
+
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from classifier.head import ClassifierHead
|
| 7 |
+
from classifier.infer import predict_query
|
| 8 |
+
from classifier.utils import get_models
|
| 9 |
+
from retriever import Retriever
|
| 10 |
+
from team.candidates import get_candidates, _available
|
| 11 |
+
from config import settings
|
| 12 |
+
|
| 13 |
+
class HealthQueryPipeline:
    """End-to-end pipeline: classify a health query, then retrieve documents
    for queries classified as medical."""

    def __init__(self, use_reranker: bool = False):
        self.use_reranker = use_reranker
        self.embedding_model: Optional[SentenceTransformer] = None
        self.classifier: Optional[ClassifierHead] = None
        self.retriever: Optional[Retriever] = None
        self.is_initialized = False

    def initialize(self):
        """Loads models and initializes the retriever."""
        if self.is_initialized:
            return

        print(f"Loading embedding model: {settings.MODEL_NAME}...")
        self.embedding_model, self.classifier = get_models(model_id=settings.CLASSIFIER_NAME)
        print("Model loaded.")

        print("Initializing retriever...")
        available_cfg = _available(settings.CORPORA_CONFIG)
        if not available_cfg:
            raise RuntimeError("No corpora files found in data/corpora. Build them first.")

        self.retriever = Retriever(
            corpora_config=available_cfg,
            use_reranker=self.use_reranker,
            embedding_model=self.embedding_model,
        )
        print("Retriever initialized.")
        self.is_initialized = True

    def predict(self, query: str, k: int = 10) -> Dict[str, Any]:
        """
        Runs the full pipeline: Classification -> Retrieval (if medical).
        """
        if not self.is_initialized:
            self.initialize()

        classification = predict_query(
            text=[query],
            embedding_model=self.embedding_model,
            classifier_head=self.classifier,
        )

        labels = classification["prediction"]
        # Probabilities are index-aligned with settings.CATEGORIES.
        probabilities = dict(zip(settings.CATEGORIES, classification['probabilities']))

        response: Dict[str, Any] = {
            "query": query,
            "classification": {
                "prediction": labels[0],
                "probabilities": probabilities,
            },
            "retrieval": [],
        }

        if "medical" in labels:
            candidates = get_candidates(
                query=query,
                retriever=self.retriever,
                k_retrieve=k,
            )
            response["retrieval"] = [asdict(candidate) for candidate in candidates]

        return response

    def get_index_progress(self):
        """Returns (current, total) of the underlying index."""
        return self.retriever.get_index_progress() if self.retriever else (0, 0)
|
reason_data_analysis.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple script to analyze healthcare reason data processing
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Add current directory to path
|
| 10 |
+
sys.path.append('.')
|
| 11 |
+
|
| 12 |
+
# Ordered keyword rules; the first rule whose keyword list matches wins.
_CATEGORY_KEYWORDS = [
    ("ROUTINE_CARE", ['routine', 'nail care', 'calluses']),
    ("PAIN_CONDITIONS", ['pain', 'ache', 'sore']),
    ("INJURIES", ['sprain', 'wound', 'injury']),
    ("SKIN_CONDITIONS", ['ingrown', 'toenail', 'callus']),
    ("STRUCTURAL_ISSUES", ['flat feet', 'plantar', 'fasciitis', 'achilles']),
    ("PROCEDURES", ['injection', 'surgical', 'consult', 'postop']),
]


def map_reason_to_category(reason: str) -> str:
    """Map a free-text visit reason to a coarse category.

    Rules in _CATEGORY_KEYWORDS are evaluated in order (first match wins);
    unmatched reasons fall back to PAIN_CONDITIONS.
    """
    reason_lower = reason.lower()
    for category, keywords in _CATEGORY_KEYWORDS:
        if any(word in reason_lower for word in keywords):
            return category
    return "PAIN_CONDITIONS"  # Default


def test_data_loading():
    """Test loading and processing the healthcare reason data.

    Returns:
        True when the Excel file loads and the analysis completes; False when
        loading fails (file missing, unreadable, or engine unavailable).
    """

    print("Testing Healthcare Reason Data Processing")
    print("=" * 40)

    # Load the data
    try:
        df = pd.read_excel('data/reason_for_visit_data.xlsx')
        print(f"✅ Successfully loaded {len(df)} records")
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        return False

    # Analyze the data
    print(f"\nDataset Info:")
    print(f"Shape: {df.shape}")
    print(f"Columns: {list(df.columns)}")

    # Show reason distribution
    print(f"\nTop 10 Reasons for Visit:")
    top_reasons = df['Reason For Visit'].value_counts().head(10)
    for reason, count in top_reasons.items():
        print(f"  {reason}: {count}")

    # Apply categorization
    df['Category'] = df['Reason For Visit'].apply(map_reason_to_category)

    print(f"\nCategory Distribution:")
    category_counts = df['Category'].value_counts()
    for category, count in category_counts.items():
        percentage = (count / len(df)) * 100
        print(f"  {category}: {count} ({percentage:.1f}%)")

    # Show examples for each category
    print(f"\nExample reasons by category:")
    for category in category_counts.index:
        examples = df[df['Category'] == category]['Reason For Visit'].head(3).tolist()
        print(f"  {category}:")
        for example in examples:
            print(f"    - {example}")

    return True


if __name__ == "__main__":
    success = test_data_loading()
    if success:
        print("\n✅ Healthcare reason data analysis completed successfully!")
    else:
        print("\n❌ Healthcare reason data analysis failed!")
|
requirements-admin.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Requirements for Administrative Query Classifier
|
| 2 |
+
pandas>=1.5.0
|
| 3 |
+
scikit-learn>=1.0.0
|
| 4 |
+
setfit>=1.0.0
|
| 5 |
+
sentence-transformers>=2.0.0
|
| 6 |
+
datasets>=2.0.0
|
| 7 |
+
matplotlib>=3.5.0
|
| 8 |
+
seaborn>=0.11.0
|
| 9 |
+
numpy>=1.21.0
|
requirements-train.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
matplotlib
|
| 2 |
+
numpy
|
| 3 |
+
pandas
|
| 4 |
+
sentence-transformers
|
| 5 |
+
torch
|
| 6 |
+
huggingface_hub
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets
|
| 2 |
+
pyarrow
|
| 3 |
+
jsonlines
|
| 4 |
+
tqdm
|
| 5 |
+
rank-bm25
|
| 6 |
+
faiss-cpu
|
| 7 |
+
sentence-transformers
|
| 8 |
+
numpy
|
| 9 |
+
scipy
|
| 10 |
+
scikit-learn
|
| 11 |
+
torch
|
| 12 |
+
huggingface_hub
|
| 13 |
+
gradio
|
| 14 |
+
streamlit
|
retriever/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .search import Retriever
|
retriever/data_schemas.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
@dataclass
class Doc:
    """A single retrievable document used by the BM25/dense indexes."""
    # Unique identifier; also feeds the embedding-cache key computation.
    id: str
    # Full text that gets tokenized/embedded for search.
    text: str
    # Optional human-readable title.
    title: str | None = None
    # Optional source-specific metadata (free-form).
    meta: dict | None = None
|
retriever/index_bm25.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rank_bm25 import BM25Okapi
|
| 2 |
+
from .utils import tokenize
|
| 3 |
+
|
| 4 |
+
class BM25Index:
    """Sparse BM25 index over a fixed collection of Doc objects."""

    def __init__(self, docs):
        self.docs = docs
        # Tokenize every document once up front; BM25Okapi scores against these.
        self.corpus_tokens = [tokenize(doc.text) for doc in docs]
        self.bm25 = BM25Okapi(self.corpus_tokens)

    def search(self, query: str, k: int = 50):
        """Return the top-k (doc, score) pairs for *query*, best first."""
        query_tokens = tokenize(query)
        scores = self.bm25.get_scores(query_tokens)
        # Stable sort by descending score keeps original doc order among ties.
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        return [(self.docs[idx], float(score)) for idx, score in ranked[:k]]
|
retriever/index_dense.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# retriever/index_dense.py
|
| 2 |
+
import os
|
| 3 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
import threading
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pickle
|
| 9 |
+
import torch
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from sentence_transformers import SentenceTransformer
|
| 12 |
+
from classifier.utils import DEVICE
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import faiss # type: ignore
|
| 16 |
+
_HAS_FAISS = True
|
| 17 |
+
except Exception:
|
| 18 |
+
_HAS_FAISS = False
|
| 19 |
+
|
| 20 |
+
def _chunks(lst, n):
|
| 21 |
+
for i in range(0, len(lst), n):
|
| 22 |
+
yield lst[i:i+n]
|
| 23 |
+
|
| 24 |
+
def _compute_cache_key(docs, model_name):
|
| 25 |
+
"""Compute a hash key for caching based on documents and model."""
|
| 26 |
+
# Create a hash from document IDs/texts and model name
|
| 27 |
+
doc_ids = "".join([d.id for d in docs])
|
| 28 |
+
content = f"{model_name}:{doc_ids}"
|
| 29 |
+
return hashlib.md5(content.encode()).hexdigest()
|
| 30 |
+
|
| 31 |
+
class DenseIndex:
|
| 32 |
+
    def __init__(self, docs, model_name="sentence-transformers/embeddinggemma-300m-medical",
                 batch_size=64, embedding_model=None, cache_dir=".cache/embeddings"):
        """Build a dense-embedding index over *docs*, encoding in a background thread.

        Args:
            docs: Sequence of Doc-like objects exposing `.id` and `.text`.
            model_name: SentenceTransformer model id to load when no model is supplied.
            batch_size: Encoding batch size used by the background ingester.
            embedding_model: Optional pre-loaded SentenceTransformer to reuse.
            cache_dir: Directory for pickled embedding caches.
        """
        self.docs = docs
        self.batch_size = batch_size
        self.cache_dir = cache_dir

        # Thread safety: lock guards ready_count/emb_batches shared with the ingest thread.
        self.lock = threading.Lock()
        self.ready_count = 0
        self.emb_batches = []  # List of numpy arrays for fallback

        # Limit intra-op parallelism; encoding already runs on a dedicated thread.
        torch.set_num_threads(1)
        if embedding_model:
            self.model = embedding_model
            self.device = self.model.device
            # Recover the real model id for cache keying when a model is passed in.
            # NOTE(review): model_card_data is typically an object, not a dict — the
            # `.get` below may raise for real SentenceTransformer instances; confirm.
            actual_model_name = getattr(self.model, 'model_card_data', {}).get('base_model', model_name)
            if hasattr(self.model, '_model_card_vars') and 'model_id' in self.model._model_card_vars:
                actual_model_name = self.model._model_card_vars['model_id']
        else:
            self.model = SentenceTransformer(model_name, device=DEVICE)
            self.device = DEVICE
            actual_model_name = model_name

        # Cache key ties the pickled embeddings to this (model, corpus) pair.
        self.cache_key = _compute_cache_key(docs, actual_model_name)
        self.cache_path = Path(cache_dir) / f"{self.cache_key}.pkl"

        # Initialize index structure
        if _HAS_FAISS:
            # We need to know dimension to init FAISS.
            # We'll init it when the first batch arrives or if we load full cache.
            self.index = None
        else:
            # NOTE(review): both branches currently set None; the branch only documents intent.
            self.index = None

        # Start background ingestion (daemon so it never blocks interpreter exit).
        self.ingest_thread = threading.Thread(target=self._ingest_embeddings, daemon=True)
        self.ingest_thread.start()
|
| 70 |
+
def _generate_embeddings(self):
|
| 71 |
+
"""Yields batches of embeddings from cache or computation."""
|
| 72 |
+
texts = [d.text for d in self.docs]
|
| 73 |
+
|
| 74 |
+
# 1. Try full cache first
|
| 75 |
+
if self.cache_path.exists():
|
| 76 |
+
print(f"Loading embeddings from cache: {self.cache_path}")
|
| 77 |
+
try:
|
| 78 |
+
with open(self.cache_path, 'rb') as f:
|
| 79 |
+
full_emb = pickle.load(f)
|
| 80 |
+
print(f"✓ Loaded {len(full_emb)} cached embeddings")
|
| 81 |
+
# Yield as a single large batch
|
| 82 |
+
yield full_emb
|
| 83 |
+
return
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f"Cache load failed: {e}, recomputing...")
|
| 86 |
+
|
| 87 |
+
# 2. Partial cache logic
|
| 88 |
+
partial_cache_path = self.cache_path.parent / f"{self.cache_path.stem}.partial.pkl"
|
| 89 |
+
start_index = 0
|
| 90 |
+
existing_embs = []
|
| 91 |
+
|
| 92 |
+
if partial_cache_path.exists():
|
| 93 |
+
try:
|
| 94 |
+
with open(partial_cache_path, 'rb') as f:
|
| 95 |
+
existing_embs = pickle.load(f)
|
| 96 |
+
|
| 97 |
+
# Yield existing chunks
|
| 98 |
+
# We assume existing_embs is a list of batches from previous run
|
| 99 |
+
# But wait, previous implementation saved list of batches.
|
| 100 |
+
# Let's verify if it saved list of batches or vstacked array.
|
| 101 |
+
# Previous impl: pickle.dump(embs, f) where embs is list of arrays.
|
| 102 |
+
|
| 103 |
+
for batch in existing_embs:
|
| 104 |
+
yield batch
|
| 105 |
+
|
| 106 |
+
start_index = sum(len(e) for e in existing_embs)
|
| 107 |
+
except Exception as e:
|
| 108 |
+
existing_embs = []
|
| 109 |
+
start_index = 0
|
| 110 |
+
|
| 111 |
+
# 3. Compute remaining
|
| 112 |
+
texts_to_process = texts[start_index:]
|
| 113 |
+
if not texts_to_process:
|
| 114 |
+
return
|
| 115 |
+
|
| 116 |
+
# We need to keep track of all embs (existing + new) to save partial/full cache
|
| 117 |
+
# But `existing_embs` might be large.
|
| 118 |
+
# We will append new batches to `existing_embs` locally to save partials.
|
| 119 |
+
|
| 120 |
+
with torch.inference_mode():
|
| 121 |
+
total_processed = start_index
|
| 122 |
+
total_batches = (len(texts) + self.batch_size - 1) // self.batch_size
|
| 123 |
+
start_batch = len(existing_embs)
|
| 124 |
+
|
| 125 |
+
for i, part in enumerate(_chunks(texts_to_process, self.batch_size), 1):
|
| 126 |
+
part_emb = self.model.encode(
|
| 127 |
+
part,
|
| 128 |
+
batch_size=self.batch_size,
|
| 129 |
+
normalize_embeddings=True,
|
| 130 |
+
convert_to_numpy=True,
|
| 131 |
+
show_progress_bar=False,
|
| 132 |
+
device=self.device,
|
| 133 |
+
)
|
| 134 |
+
batch_emb = part_emb.astype(np.float32)
|
| 135 |
+
yield batch_emb
|
| 136 |
+
|
| 137 |
+
existing_embs.append(batch_emb)
|
| 138 |
+
total_processed += len(part)
|
| 139 |
+
|
| 140 |
+
# Save partial
|
| 141 |
+
with open(partial_cache_path, 'wb') as f:
|
| 142 |
+
pickle.dump(existing_embs, f)
|
| 143 |
+
|
| 144 |
+
def _ingest_embeddings(self):
|
| 145 |
+
"""Background thread to ingest embeddings from generator."""
|
| 146 |
+
all_embs = []
|
| 147 |
+
|
| 148 |
+
for batch_emb in self._generate_embeddings():
|
| 149 |
+
with self.lock:
|
| 150 |
+
if _HAS_FAISS:
|
| 151 |
+
if self.index is None:
|
| 152 |
+
d = batch_emb.shape[1]
|
| 153 |
+
self.index = faiss.IndexFlatIP(d)
|
| 154 |
+
self.index.add(batch_emb)
|
| 155 |
+
|
| 156 |
+
# We also keep track for fallback or saving
|
| 157 |
+
self.emb_batches.append(batch_emb)
|
| 158 |
+
self.ready_count += len(batch_emb)
|
| 159 |
+
|
| 160 |
+
all_embs.append(batch_emb)
|
| 161 |
+
|
| 162 |
+
# Finalize
|
| 163 |
+
full_emb = np.vstack(all_embs).astype(np.float32)
|
| 164 |
+
|
| 165 |
+
# Save full cache
|
| 166 |
+
self.cache_path.parent.mkdir(parents=True, exist_ok=True)
|
| 167 |
+
with open(self.cache_path, 'wb') as f:
|
| 168 |
+
pickle.dump(full_emb, f)
|
| 169 |
+
print(f"✓ Saved embeddings to cache: {self.cache_path}")
|
| 170 |
+
|
| 171 |
+
# Cleanup partial
|
| 172 |
+
partial_cache_path = self.cache_path.parent / f"{self.cache_path.stem}.partial.pkl"
|
| 173 |
+
if partial_cache_path.exists():
|
| 174 |
+
partial_cache_path.unlink()
|
| 175 |
+
|
| 176 |
+
def search(self, query: str, k: int = 50):
|
| 177 |
+
qv = self.model.encode(
|
| 178 |
+
[query],
|
| 179 |
+
normalize_embeddings=True,
|
| 180 |
+
convert_to_numpy=True,
|
| 181 |
+
show_progress_bar=False,
|
| 182 |
+
device=self.device,
|
| 183 |
+
).astype(np.float32)[0]
|
| 184 |
+
|
| 185 |
+
with self.lock:
|
| 186 |
+
current_count = self.ready_count
|
| 187 |
+
if current_count == 0:
|
| 188 |
+
print("Warning: Index not yet initialized, returning empty results.")
|
| 189 |
+
return []
|
| 190 |
+
|
| 191 |
+
# If we have partial data, we search it.
|
| 192 |
+
if _HAS_FAISS and self.index is not None:
|
| 193 |
+
# FAISS index is updated incrementally
|
| 194 |
+
D, I = self.index.search(qv.reshape(1, -1), min(k, current_count))
|
| 195 |
+
return [(self.docs[int(i)], float(D[0][j])) for j, i in enumerate(I[0]) if i != -1]
|
| 196 |
+
|
| 197 |
+
# NumPy fallback
|
| 198 |
+
# We might have multiple batches, need to stack them for search
|
| 199 |
+
# Optimization: cache the stacked version if it hasn't changed?
|
| 200 |
+
# For now, just stack what we have.
|
| 201 |
+
curr_emb = np.vstack(self.emb_batches)
|
| 202 |
+
|
| 203 |
+
sims = curr_emb @ qv
|
| 204 |
+
effective_k = min(k, len(sims))
|
| 205 |
+
|
| 206 |
+
if effective_k >= len(sims):
|
| 207 |
+
order = np.argsort(-sims)
|
| 208 |
+
else:
|
| 209 |
+
idx = np.argpartition(-sims, kth=effective_k-1)[:effective_k]
|
| 210 |
+
order = idx[np.argsort(-sims[idx])]
|
| 211 |
+
|
| 212 |
+
return [(self.docs[int(i)], float(sims[int(i)])) for i in order]
|
| 213 |
+
|
| 214 |
+
def get_progress(self):
|
| 215 |
+
"""Returns (current_count, total_count) of indexed documents."""
|
| 216 |
+
with self.lock:
|
| 217 |
+
return self.ready_count, len(self.docs)
|
retriever/ingest.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json, pathlib
|
| 2 |
+
from .data_schemas import Doc
|
| 3 |
+
|
| 4 |
+
def load_jsonl(path: str, text_fields=("question","answer")):
    """Load a JSONL file into a list of Doc objects.

    Each non-empty line contributes one Doc. The document text is
    ``row["text"]`` when present and non-empty; otherwise the values of
    `text_fields` joined with spaces. Title falls back from ``"title"`` to
    ``"category"`` to ``""``. The full parsed row is kept in ``Doc.meta``;
    the Doc id defaults to ``"<file-stem>:<line-number>"``.

    Args:
        path: path to the .jsonl file.
        text_fields: fields to concatenate when no "text" field exists.

    Returns:
        list[Doc]
    """
    p = pathlib.Path(path)
    docs = []
    with p.open(encoding="utf-8") as f:
        for i, line in enumerate(f):
            # Fix: tolerate blank/trailing lines (common in hand-edited JSONL)
            # instead of crashing in json.loads. Line numbering is unchanged
            # because enumerate still counts every physical line.
            if not line.strip():
                continue
            row = json.loads(line)
            # Collect fields; allow either "text" or joined fields
            if "text" in row and row["text"]:
                combined = row["text"]
            else:
                combined = " ".join([row.get(tf, "") for tf in text_fields]).strip()
            title = row.get("title") or row.get("category") or ""
            docs.append(Doc(
                id=str(row.get("id", f"{p.stem}:{i}")),
                text=combined,
                title=title,
                meta=row
            ))
    return docs
|
retriever/rrf.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
|
| 3 |
+
def rrf(rank_lists, k=10, K=60):
    """Fuse several ranked (doc, score) lists with Reciprocal Rank Fusion.

    Each document receives 1 / (K + rank) from every list it appears in
    (rank is 1-based); the top `k` documents by fused score are returned
    as (doc, fused_score) pairs.
    """
    fused_scores = defaultdict(float)
    by_id = {}
    for ranked in rank_lists:
        for position, (doc, _ignored) in enumerate(ranked, start=1):
            by_id[doc.id] = doc
            fused_scores[doc.id] += 1.0 / (K + position)
    ordered = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
    return [(by_id[doc_id], score) for doc_id, score in ordered[:k]]
|
retriever/search.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .index_bm25 import BM25Index
|
| 2 |
+
from .index_dense import DenseIndex
|
| 3 |
+
from .rrf import rrf
|
| 4 |
+
try:
|
| 5 |
+
from .rerank import CrossEncoderReranker
|
| 6 |
+
except Exception:
|
| 7 |
+
CrossEncoderReranker = None
|
| 8 |
+
from .ingest import load_jsonl
|
| 9 |
+
|
| 10 |
+
class Retriever:
    """Hybrid retriever: BM25 + dense search fused via reciprocal rank fusion,
    with an optional cross-encoder reranking stage."""

    def __init__(self, corpora_config, use_reranker=False, embedding_model=None):
        """Load every corpus in `corpora_config`, then build both indexes
        over the union of all documents."""
        self.corpora = {}
        all_docs = []
        for corpus_name, cfg in corpora_config.items():
            fields = tuple(cfg.get("text_fields", ("question", "answer")))
            corpus_docs = load_jsonl(cfg["path"], fields)
            self.corpora[corpus_name] = corpus_docs
            all_docs.extend(corpus_docs)
        self.bm25 = BM25Index(all_docs)
        self.dense = DenseIndex(all_docs, embedding_model=embedding_model)
        if use_reranker and CrossEncoderReranker:
            self.reranker = CrossEncoderReranker()
        else:
            self.reranker = None

    def retrieve(self, query, k=10, for_ui=True):
        """Run both indexes, fuse with RRF, optionally rerank.

        Returns (doc, score) pairs when `for_ui` is False, otherwise a list of
        UI-friendly dicts with a 300-char snippet.
        """
        lexical = self.bm25.search(query, k=100)
        semantic = self.dense.search(query, k=100)
        fused = rrf([lexical, semantic], k=max(k, 20))
        if self.reranker:
            top = self.reranker.rerank(query, [doc for doc, _ in fused])[:k]
            results = [(doc, float(score)) for doc, score in top]
        else:
            results = fused[:k]
        if not for_ui:
            return results
        formatted = []
        for doc, score in results:
            snippet = doc.text[:300]
            if len(doc.text) > 300:
                snippet += "..."
            formatted.append({
                "id": doc.id,
                "title": doc.title,
                "snippet": snippet,
                "score": score,
                "meta": doc.meta,
            })
        return formatted

    def get_index_progress(self):
        """Returns (current, total) from dense index."""
        return self.dense.get_progress()
|
retriever/utils.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
# Simple non-word splitter (keeps letters/numbers, splits on punctuation/whitespace)
_WS = re.compile(r"\W+", flags=re.UNICODE)

def tokenize(s: str) -> list[str]:
    """
    Lowercase + split on non-word chars. Returns [] for None/empty.
    Used by BM25 to build the tokenized corpus and query.
    """
    if not s:
        return []
    # filter(None, ...) drops the empty strings split() emits at the edges.
    return list(filter(None, _WS.split(s.lower())))
|
| 14 |
+
|
| 15 |
+
|
scripts/query.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#! /usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import readline
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
from pyserini.search.lucene import LuceneSearcher
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main():
    """Interactive BM25 search REPL over a Pyserini Lucene index.

    Index directory comes from argv[1] (default "indexes/pubmed"). Exits on
    EOF (Ctrl-D), Ctrl-C, or a 'quit'/'exit' command.
    """
    index_dir = sys.argv[1] if len(sys.argv) > 1 else "indexes/pubmed"

    searcher = LuceneSearcher(index_dir)

    print(f"Loaded {searcher.num_docs} documents from {index_dir}")
    print(f"(Ctrl-D or 'quit' to exit)\n")

    while True:
        try:
            query = input("PubMed> ").strip()
            if not query or query.lower() in ['quit', 'exit']:
                break

            hits = searcher.search(query, k=10)

            print(f"{len(hits)}/{searcher.num_docs} matching documents found\n")

            if not hits:
                print("No results found.\n")
                continue

            for rank, hit in enumerate(hits, 1):
                record = json.loads(searcher.doc(hit.docid).raw())

                title = record.get('title', '')
                contents = record.get('contents', '')

                # "contents" stores title + abstract; strip the leading title
                # when present so only the abstract is previewed.
                if contents.startswith(title):
                    abstract = contents[len(title):]
                else:
                    abstract = contents

                print(f"{rank}. PMID {hit.docid} \"{title}\" (score: {hit.score:.4f})")
                print(f"   {abstract[:120]}...\n")

        except (EOFError, KeyboardInterrupt):
            # Both interrupts end the session the same way.
            print("\nBye!")
            break
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
main()
|