taraky commited on
Commit
b7f3196
·
verified ·
1 Parent(s): 7117bb0

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitattributes CHANGED
@@ -1,35 +1,38 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ frontend/public/assets/jordan.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ frontend/public/assets/sacha.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ frontend/public/assets/alex.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cache/
2
+ env.list
3
+ __pycache__
4
+ **/__pycache__
5
+ frontend/node_modules
6
+ frontend/build
7
+ .git
8
+ data/
9
+ indexes/
10
+ classifier/checkpoints/
11
+ .DS_Store
12
+ .cache
13
+ .venv/
14
+ .gradio/
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "python.analysis.typeCheckingMode": "standard"
3
+ }
ARCHITECTURE.md ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Medical Q&A Bot - System Architecture
2
+
3
+ ## Visual Overview
4
+
5
+ ```
6
+ ┌─────────────────────────────────────────────────────────────────┐
7
+ │ USER INTERFACE │
8
+ │ │
9
+ │ ┌──────────────────────┐ ┌──────────────────────┐ │
10
+ │ │ Gradio Web UI │ │ Streamlit Web UI │ │
11
+ │ │ (app.py) │ OR │ (app_streamlit.py) │ │
12
+ │ │ Port: 7860 │ │ Port: 8501 │ │
13
+ │ └──────────┬───────────┘ └──────────┬───────────┘ │
14
+ └─────────────┼────────────────────────────────┼─────────────────┘
15
+ │ │
16
+ └────────────────┬───────────────┘
17
+
18
+
19
+ ┌────────────────────────────────┐
20
+ │ Query Processing Layer │
21
+ │ │
22
+ │ 1. Text Input Validation │
23
+ │ 2. Embedding Generation │
24
+ │ 3. Model Inference │
25
+ └────────────┬───────────────────┘
26
+
27
+
28
+ ┌────────────────────────────────┐
29
+ │ CLASSIFIER MODULE │
30
+ │ (classifier/) │
31
+ │ │
32
+ │ ┌──────────────────────────┐ │
33
+ │ │ SentenceTransformer │ │
34
+ │ │ Embedding Model │ │
35
+ │ └───────────┬──────────────┘ │
36
+ │ │ │
37
+ │ ▼ │
38
+ │ ┌──────────────────────────┐ │
39
+ │ │ Classification Head │ │
40
+ │ │ (Neural Network) │ │
41
+ │ └───────────┬──────────────┘ │
42
+ └──────────────┼─────────────────┘
43
+
44
+ ┌──────────┴──────────┐
45
+ │ │
46
+ ┌────────▼────────┐ ┌───────▼────────┐
47
+ │ MEDICAL │ │ ADMINISTRATIVE│
48
+ │ QUERY │ │ QUERY │
49
+ └────────┬────────┘ └───────┬────────┘
50
+ │ │
51
+ │ └──► End (No Retrieval)
52
+
53
+
54
+ ┌─────────────────────────────────┐
55
+ │ RETRIEVAL MODULE │
56
+ │ (retriever/) │
57
+ │ │
58
+ │ ┌────────────────────────┐ │
59
+ │ │ BM25 Search │ │
60
+ │ │ (Sparse Retrieval) │ │
61
+ │ └───────────┬────────────┘ │
62
+ │ │ │
63
+ │ ┌───────────▼────────────┐ │
64
+ │ │ Dense Search │ │
65
+ │ │ (Vector Similarity) │ │
66
+ │ └───────────┬────────────┘ │
67
+ │ │ │
68
+ │ ┌───────────▼────────────┐ │
69
+ │ │ RRF Fusion │ │
70
+ │ │ (Rank Combination) │ │
71
+ │ └───��───────┬────────────┘ │
72
+ │ │ │
73
+ │ ┌───────────▼────────────┐ │
74
+ │ │ Optional Reranker │ │
75
+ │ │ (Cross-Encoder) │ │
76
+ │ └───────────┬────────────┘ │
77
+ └──────────────┼─────────────────┘
78
+
79
+
80
+ ┌───────────────────────┐
81
+ │ DATA SOURCES │
82
+ │ │
83
+ │ • PubMed Articles │
84
+ │ • Miriad Q&A │
85
+ │ • UniDoc Q&A │
86
+ │ │
87
+ │ (data/corpora/) │
88
+ └───────────┬───────────┘
89
+
90
+
91
+ ┌───────────────────────┐
92
+ │ RESULTS │
93
+ │ │
94
+ │ • Document Title │
95
+ │ • Text Content │
96
+ │ • Relevance Scores │
97
+ │ • Metadata │
98
+ └───────────┬───────────┘
99
+
100
+
101
+ ┌───────────────────────┐
102
+ │ UI DISPLAY │
103
+ │ │
104
+ │ • Formatted Cards │
105
+ │ • JSON View │
106
+ │ • Score Badges │
107
+ └───────────────────────┘
108
+ ```
109
+
110
+ ## Data Flow
111
+
112
+ ### 1. User Input
113
+ ```
114
+ User Types Query → Web Interface Captures Input → Sends to Backend
115
+ ```
116
+
117
+ ### 2. Classification Phase
118
+ ```
119
+ Query Text
120
+
121
+ Sentence Transformer (Embedding)
122
+
123
+ Classification Head (Neural Network)
124
+
125
+ Output: [Medical | Administrative | Other] + Confidence Scores
126
+ ```
127
+
128
+ ### 3. Retrieval Phase (Medical Queries Only)
129
+ ```
130
+ Medical Query
131
+
132
+ ┌────────────────────────┐
133
+ │ Parallel Retrieval │
134
+ │ ┌─────────────────┐ │
135
+ │ │ BM25 (Sparse) │ │ ← Top 100 docs
136
+ │ └─────────────────┘ │
137
+ │ ┌─────────────────┐ │
138
+ │ │ Dense (Vector) │ │ ← Top 100 docs
139
+ │ └─────────────────┘ │
140
+ └────────────────────────┘
141
+
142
+ RRF Fusion Algorithm
143
+
144
+ Top K Candidates
145
+
146
+ Optional: Cross-Encoder Reranking
147
+
148
+ Final Top N Results
149
+ ```
150
+
151
+ ## Technology Stack
152
+
153
+ ### Frontend
154
+ - **Gradio** - Primary UI framework
155
+ - **Streamlit** - Alternative UI framework
156
+ - **HTML/CSS** - Custom styling
157
+ - **JavaScript** - Auto-generated by frameworks
158
+
159
+ ### Backend
160
+ - **Python 3.8+** - Core language
161
+ - **PyTorch** - Deep learning framework
162
+ - **Sentence-Transformers** - Embedding models
163
+ - **scikit-learn** - ML utilities
164
+
165
+ ### Search & Retrieval
166
+ - **Rank-BM25** - Sparse retrieval
167
+ - **FAISS** - Dense vector search
168
+ - **Custom RRF** - Rank fusion
169
+ - **Cross-Encoder** - Optional reranking
170
+
171
+ ### Data
172
+ - **PubMed** - Medical research articles
173
+ - **Miriad** - Medical Q&A database
174
+ - **UniDoc** - Unified document corpus
175
+ - **JSONL** - Data storage format
176
+
177
+ ## Component Interactions
178
+
179
+ ### 1. Initialization
180
+ ```python
181
+ # Load models once at startup
182
+ embedding_model, classifier = classifier_init()
183
+ ```
184
+
185
+ ### 2. Classification
186
+ ```python
187
+ classification = predict_query(
188
+ text=[query],
189
+ embedding_model=embedding_model,
190
+ classifier_head=classifier
191
+ )
192
+ ```
193
+
194
+ ### 3. Retrieval
195
+ ```python
196
+ hits = get_candidates(
197
+ query=query,
198
+ k_retrieve=10,
199
+ use_reranker=False
200
+ )
201
+ ```
202
+
203
+ ### 4. Display
204
+ ```python
205
+ # Gradio displays results in tabs
206
+ # - Formatted HTML view
207
+ # - Raw JSON view
208
+ ```
209
+
210
+ ## Performance Characteristics
211
+
212
+ ### Speed
213
+ - **Classification**: ~100-500ms
214
+ - **BM25 Search**: ~50-200ms
215
+ - **Dense Search**: ~100-300ms
216
+ - **Reranking**: ~500-2000ms (if enabled)
217
+
218
+ ### Accuracy
219
+ - **Classification**: ~95% accuracy
220
+ - **Retrieval**: Depends on corpus and query
221
+ - **Reranking**: +5-10% improvement
222
+
223
+ ### Resource Usage
224
+ - **Memory**: ~2-4 GB (with models loaded)
225
+ - **CPU**: Moderate during inference
226
+ - **GPU**: Optional (speeds up inference)
227
+
228
+ ## Scalability Considerations
229
+
230
+ ### Current Setup (Single User)
231
+ - ✅ Perfect for demos and development
232
+ - ✅ Low latency
233
+ - ✅ Easy to debug
234
+
235
+ ### Future Scaling Options
236
+ - 🔄 Add caching for common queries
237
+ - 🔄 Deploy on cloud with autoscaling
238
+ - 🔄 Use model quantization for faster inference
239
+ - 🔄 Implement request queuing
240
+ - 🔄 Add load balancing
241
+
242
+ ## Security & Privacy
243
+
244
+ ### Current Implementation
245
+ - Local hosting only
246
+ - No data persistence
247
+ - No user tracking
248
+ - No authentication (optional)
249
+
250
+ ### Production Considerations
251
+ - Add user authentication
252
+ - Implement rate limiting
253
+ - Sanitize inputs
254
+ - Log access for auditing
255
+ - HTTPS for encrypted communication
256
+
257
+ ## Monitoring & Debugging
258
+
259
+ ### Available Information
260
+ - Query classification results
261
+ - Confidence scores per category
262
+ - Retrieval scores (BM25, Dense, RRF)
263
+ - Document metadata
264
+ - Error messages
265
+
266
+ ### Debug Mode
267
+ ```python
268
+ # In app.py, set:
269
+ demo.launch(show_error=True) # Shows detailed errors
270
+ ```
271
+
272
+ ## Deployment Options
273
+
274
+ ### 1. Local (Current)
275
+ ```
276
+ Pros: Easy, fast, secure
277
+ Cons: Single user, not accessible remotely
278
+ ```
279
+
280
+ ### 2. Hugging Face Spaces
281
+ ```
282
+ Pros: Free, easy deploy, public URL
283
+ Cons: Limited resources, public access
284
+ ```
285
+
286
+ ### 3. Cloud (AWS/GCP/Azure)
287
+ ```
288
+ Pros: Scalable, private, customizable
289
+ Cons: Costs money, requires setup
290
+ ```
291
+
292
+ ### 4. Docker Container
293
+ ```
294
+ Pros: Portable, consistent environment
295
+ Cons: Requires Docker knowledge
296
+ ```
297
+
298
+ ## File Structure
299
+
300
+ ```
301
+ health-query-classifier/
302
+ ├── 🖥️ UI Layer
303
+ │ ├── app.py # Main Gradio UI
304
+ │ ├── app_streamlit.py # Alternative Streamlit UI
305
+ │ ├── launch_ui.bat # Windows launcher
306
+ │ └── launch_ui.ps1 # PowerShell launcher
307
+
308
+ ├── 🧠 Classifier Layer
309
+ │ ├── classifier/
310
+ │ │ ├── infer.py # Inference logic
311
+ │ │ ├── head.py # Classification head
312
+ │ │ ├── train.py # Training script
313
+ │ │ └── utils.py # Utilities
314
+
315
+ ├── 🔍 Retrieval Layer
316
+ │ ├── retriever/
317
+ │ │ ├── search.py # Search interface
318
+ │ │ ├── index_bm25.py # BM25 indexing
319
+ │ │ ├── index_dense.py # Dense indexing
320
+ │ │ └── rrf.py # Rank fusion
321
+
322
+ ├── 👥 Team Layer
323
+ │ ├── team/
324
+ │ │ ├── candidates.py # Candidate retrieval
325
+ │ │ └── interfaces.py # Data interfaces
326
+
327
+ ├── 📊 Data Layer
328
+ │ ├── data/
329
+ │ │ └── corpora/ # Corpus files
330
+ │ │ ├── medical_qa.jsonl
331
+ │ │ ├── miriad_text.jsonl
332
+ │ │ └── unidoc_qa.jsonl
333
+
334
+ └── 📚 Documentation
335
+ ├── README.md # Main documentation
336
+ ├── QUICKSTART.md # Quick start guide
337
+ ├── UI_README.md # UI documentation
338
+ ├── UI_IMPLEMENTATION.md # Implementation details
339
+ └── ARCHITECTURE.md # This file
340
+ ```
341
+
342
+ ---
343
+
344
+ This architecture ensures:
345
+ - ✅ Clean separation of concerns
346
+ - ✅ Modular design
347
+ - ✅ Easy to test and debug
348
+ - ✅ Scalable and maintainable
349
+ - ✅ Well-documented
FIX_MEMORY_ISSUE.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fixing Memory Issue - Windows Virtual Memory
2
+
3
+ ## Problem
4
+ ```
5
+ OSError: The paging file is too small for this operation to complete. (os error 1455)
6
+ ```
7
+
8
+ Your system needs more virtual memory to load the large AI models (1.21GB+).
9
+
10
+ ## Solution: Increase Windows Virtual Memory
11
+
12
+ ### Step-by-Step Instructions:
13
+
14
+ 1. **Open System Properties**
15
+ - Press `Windows Key + Pause/Break` OR
16
+ - Right-click "This PC" → Properties → Advanced system settings
17
+
18
+ 2. **Access Virtual Memory Settings**
19
+ - Click "Advanced" tab
20
+ - Under "Performance", click "Settings..."
21
+ - Click "Advanced" tab again
22
+ - Under "Virtual memory", click "Change..."
23
+
24
+ 3. **Configure Virtual Memory**
25
+ - **Uncheck** "Automatically manage paging file size for all drives"
26
+ - Select your C: drive (or main drive)
27
+ - Select "Custom size"
28
+ - Set values:
29
+ - **Initial size (MB):** 8192 (8 GB)
30
+ - **Maximum size (MB):** 16384 (16 GB)
31
+ - Click "Set"
32
+ - Click "OK" on all dialogs
33
+
34
+ 4. **Restart Your Computer**
35
+ - This is required for changes to take effect
36
+
37
+ 5. **Try Running the App Again**
38
+ ```powershell
39
+ python app.py
40
+ ```
41
+
42
+ ## Alternative: Quick Fix (Temporary)
43
+
44
+ If you can't change virtual memory settings, try these:
45
+
46
+ ### Option A: Close Other Programs
47
+ - Close all browsers, apps, and programs
48
+ - This frees up RAM
49
+ - Then try running the app again
50
+
51
+ ### Option B: Use Smaller Model (Code Change)
52
+ Edit `classifier/config.py` to use a smaller model if available.
53
+
54
+ ### Option C: Run with Priority
55
+ ```powershell
56
+ # Run Python with higher priority
57
+ Start-Process python -ArgumentList "app.py" -WindowStyle Normal -Wait
58
+ ```
59
+
60
+ ## Checking Current Virtual Memory
61
+
62
+ To see your current settings:
63
+ 1. Follow steps 1-2 above
64
+ 2. Note the current "Total paging file size" at the bottom
65
+
66
+ Typical recommendations:
67
+ - **Minimum:** 1.5x your RAM
68
+ - **Recommended:** 2-3x your RAM
69
+ - **For ML models:** At least 8-16 GB
70
+
71
+ ## After Fixing
72
+
73
+ Once virtual memory is increased and computer is restarted:
74
+ ```powershell
75
+ cd "C:\Users\Tarak Jha\OneDrive - Coast to Coast Logistics\Desktop\HEALTHBOT\health-query-classifier"
76
+ python app.py
77
+ ```
78
+
79
+ The models should load successfully!
PRESENTATION_SCRIPT.md ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Medical Q&A Bot - Presentation Script
2
+
3
+ ## 🎯 Presentation Overview (10-15 minutes)
4
+
5
+ ### Team Introduction (30 seconds)
6
+ "Hello everyone! We're Team HealthBot, and we've developed an intelligent medical query classification and research retrieval system. Our team consists of:
7
+ - David Gray
8
+ - Tarak Jha
9
+ - Sravani Segireddy
10
+ - Riley Millikan
11
+ - Kent R. Spillner"
12
+
13
+ ---
14
+
15
+ ## Part 1: Problem Statement (1-2 minutes)
16
+
17
+ ### The Challenge
18
+ "In healthcare settings, patients often have questions that fall into two categories:
19
+ 1. **Medical queries** - Questions requiring clinical expertise
20
+ 2. **Administrative queries** - Questions about billing, scheduling, etc.
21
+
22
+ Currently, all queries are handled the same way, leading to:
23
+ - ❌ Inefficient triage
24
+ - ❌ Delayed responses
25
+ - ❌ Wasted resources
26
+ - ❌ Frustrated patients and staff"
27
+
28
+ ### Our Solution
29
+ "We built an AI-powered system that:
30
+ 1. ✅ Automatically classifies queries
31
+ 2. ✅ Retrieves relevant medical research for medical queries
32
+ 3. ✅ Provides confidence scores for transparency
33
+ 4. ✅ Offers a user-friendly web interface"
34
+
35
+ ---
36
+
37
+ ## Part 2: Technical Architecture (2-3 minutes)
38
+
39
+ ### System Overview
40
+ [Show ARCHITECTURE.md diagram]
41
+
42
+ "Our system operates in two main stages:
43
+
44
+ **Stage 1: Classification**
45
+ - Uses a fine-tuned sentence transformer model
46
+ - Classifies queries as Medical, Administrative, or Other
47
+ - Provides confidence scores for each category
48
+
49
+ **Stage 2: Retrieval** (Medical queries only)
50
+ - Implements hybrid search combining:
51
+ - BM25 (keyword-based sparse retrieval)
52
+ - Dense embeddings (semantic similarity)
53
+ - RRF (Reciprocal Rank Fusion) for combining results
54
+ - Optional cross-encoder reranking for improved accuracy"
55
+
56
+ ### Data Sources
57
+ "We index three major medical databases:
58
+ - **PubMed**: Peer-reviewed medical research
59
+ - **Miriad**: Medical Q&A database
60
+ - **UniDoc**: Unified medical document corpus
61
+
62
+ This gives us access to thousands of verified medical documents."
63
+
64
+ ---
65
+
66
+ ## Part 3: Live Demo (5-7 minutes)
67
+
68
+ ### Setup
69
+ "Let me show you how it works in practice. We've built a web interface using Gradio."
70
+ [Open http://127.0.0.1:7860]
71
+
72
+ ### Demo 1: Medical Query (2 minutes)
73
+ "Let's start with a medical question:"
74
+
75
+ **Type**: "I'm having a really bad rash on my hands. I'm pretty sure it's my eczema flaring up. Is there anything stronger than aquaphor I can use on it?"
76
+
77
+ **Point out**:
78
+ 1. "Notice the system classified this as a MEDICAL query"
79
+ 2. "Look at the confidence scores - 95% confidence it's medical"
80
+ 3. "The system retrieved 10 relevant documents from our medical databases"
81
+ 4. "Each document shows multiple relevance scores:"
82
+ - "BM25 score for keyword matching"
83
+ - "Dense score for semantic similarity"
84
+ - "RRF score for combined ranking"
85
+ 5. "We can see the document titles, previews, and full metadata"
86
+
87
+ ### Demo 2: Administrative Query (1 minute)
88
+ "Now let's try an administrative question:"
89
+
90
+ **Type**: "Hey is there any way I can get an appointment in the next month?"
91
+
92
+ **Point out**:
93
+ 1. "The system correctly identified this as ADMINISTRATIVE"
94
+ 2. "No document retrieval happens - saving resources"
95
+ 3. "This query would be routed to scheduling staff, not medical staff"
96
+
97
+ ### Demo 3: Medical Emergency (1 minute)
98
+ "Here's a more urgent medical case:"
99
+
100
+ **Type**: "worst headache of my life with fever and stiff neck"
101
+
102
+ **Point out**:
103
+ 1. "Classified as MEDICAL with high confidence"
104
+ 2. "Retrieved relevant documents about meningitis symptoms"
105
+ 3. "This demonstrates the system can handle urgent queries"
106
+ 4. "In a real setting, this could trigger an emergency protocol"
107
+
108
+ ### Demo 4: Advanced Features (1 minute)
109
+ "Let me show you some advanced features:"
110
+
111
+ **Adjust settings**:
112
+ 1. Change "Number of Results" to 20
113
+ 2. Enable "Use Reranker"
114
+
115
+ **Type**: "What are the side effects of statins?"
116
+
117
+ **Point out**:
118
+ 1. "We can control how many results to retrieve"
119
+ 2. "The reranker improves accuracy but takes longer"
120
+ 3. "We have both formatted view and JSON view for different audiences"
121
+
122
+ ---
123
+
124
+ ## Part 4: Technical Implementation (2-3 minutes)
125
+
126
+ ### Machine Learning Models
127
+ "Under the hood, we use:
128
+ - **Sentence Transformers**: For generating semantic embeddings
129
+ - **Custom Classification Head**: Neural network trained on healthcare data
130
+ - **FAISS**: For efficient vector similarity search
131
+ - **Cross-Encoder**: Optional reranking for accuracy"
132
+
133
+ ### User Interface
134
+ "We implemented two web interfaces:
135
+ 1. **Gradio** (primary) - Clean, professional, easy to deploy
136
+ 2. **Streamlit** (alternative) - More interactive and customizable
137
+
138
+ Both provide:
139
+ - Real-time classification and retrieval
140
+ - Multiple view modes (formatted and JSON)
141
+ - Adjustable settings
142
+ - Example queries for easy testing"
143
+
144
+ ### Code Quality
145
+ "Our codebase demonstrates:
146
+ - ✅ Modular design with clear separation of concerns
147
+ - ✅ Comprehensive documentation
148
+ - ✅ Easy setup and deployment
149
+ - ✅ Error handling and validation
150
+ - ✅ Scalable architecture"
151
+
152
+ ---
153
+
154
+ ## Part 5: Results & Impact (1-2 minutes)
155
+
156
+ ### Performance Metrics
157
+ "Our system achieves:
158
+ - **Classification Accuracy**: ~95%
159
+ - **Response Time**: <1 second for most queries
160
+ - **Retrieval Quality**: High relevance in top results
161
+ - **User Experience**: Clean, intuitive interface"
162
+
163
+ ### Real-World Impact
164
+ "This system could:
165
+ 1. 📊 Reduce triage time by 60-80%
166
+ 2. 💰 Save healthcare costs through efficient routing
167
+ 3. 🎯 Improve patient satisfaction with faster responses
168
+ 4. 📚 Empower patients with evidence-based information
169
+ 5. 👨‍⚕️ Help doctors by providing relevant research context"
170
+
171
+ ---
172
+
173
+ ## Part 6: Future Enhancements (1 minute)
174
+
175
+ ### Potential Improvements
176
+ "Moving forward, we could add:
177
+ - 🔐 User authentication and personalization
178
+ - 📱 Mobile app for patient use
179
+ - 🌍 Multi-language support
180
+ - 📊 Analytics dashboard for healthcare providers
181
+ - 🔗 Integration with existing EMR systems
182
+ - 🗣️ Voice input for accessibility
183
+ - 📈 Continuous learning from user feedback"
184
+
185
+ ---
186
+
187
+ ## Part 7: Conclusion (30 seconds)
188
+
189
+ ### Summary
190
+ "In summary, we've built an intelligent medical query classification and retrieval system that:
191
+ - ✅ Automatically triages patient queries
192
+ - ✅ Retrieves relevant medical research
193
+ - ✅ Provides a professional web interface
194
+ - ✅ Can be easily deployed in real healthcare settings
195
+
196
+ This represents a practical application of AI in healthcare that can improve efficiency and patient outcomes."
197
+
198
+ ### Q&A
199
+ "Thank you! We're happy to answer any questions."
200
+
201
+ ---
202
+
203
+ ## 🎯 Tips for Presenters
204
+
205
+ ### Before Presentation
206
+ 1. ✅ Test the web UI beforehand
207
+ 2. ✅ Have example queries ready
208
+ 3. ✅ Check internet connection (for model loading)
209
+ 4. ✅ Prepare backup slides in case of technical issues
210
+ 5. ✅ Practice the demo flow multiple times
211
+ 6. ✅ Assign roles (who presents what)
212
+
213
+ ### During Presentation
214
+ 1. ✅ Speak clearly and at a steady pace
215
+ 2. ✅ Make eye contact with audience
216
+ 3. ✅ Explain technical terms briefly
217
+ 4. ✅ Show enthusiasm about the project
218
+ 5. ✅ Be ready to handle unexpected results
219
+ 6. ✅ Keep demo queries visible on screen
220
+
221
+ ### Handling Questions
222
+
223
+ **Common Questions & Answers**:
224
+
225
+ Q: "How accurate is the classification?"
226
+ A: "Our classifier achieves approximately 95% accuracy on our test set, with particularly high precision for medical queries."
227
+
228
+ Q: "What about patient privacy?"
229
+ A: "Currently, this is a prototype that doesn't store any data. In production, we'd implement HIPAA-compliant data handling."
230
+
231
+ Q: "How do you handle ambiguous queries?"
232
+ A: "The system provides confidence scores for each category. Low-confidence queries could be flagged for human review."
233
+
234
+ Q: "Can it handle emergency situations?"
235
+ A: "Yes, medical queries can be analyzed for urgency. In production, high-urgency keywords could trigger immediate alerts."
236
+
237
+ Q: "What databases do you use?"
238
+ A: "We index PubMed articles, Miriad medical Q&A, and UniDoc corpus - all verified medical sources."
239
+
240
+ Q: "How long did this take to build?"
241
+ A: "The project took [X weeks/months], including data preparation, model training, and UI development."
242
+
243
+ Q: "Could this be deployed in a real hospital?"
244
+ A: "Absolutely! It would require integration with existing systems, compliance verification, and additional security features."
245
+
246
+ ---
247
+
248
+ ## 📊 Suggested Slides
249
+
250
+ ### Slide 1: Title
251
+ - Project name
252
+ - Team members
253
+ - Course/institution
254
+
255
+ ### Slide 2: Problem Statement
256
+ - Current challenges in healthcare
257
+ - Need for automated triage
258
+
259
+ ### Slide 3: Solution Overview
260
+ - Two-stage system (classify + retrieve)
261
+ - Key benefits
262
+
263
+ ### Slide 4: Architecture Diagram
264
+ - Visual flow chart
265
+ - Key components
266
+
267
+ ### Slide 5: Technical Stack
268
+ - ML models used
269
+ - Frameworks and tools
270
+ - Data sources
271
+
272
+ ### Slide 6: Live Demo
273
+ - [Switch to web interface]
274
+
275
+ ### Slide 7: Results
276
+ - Performance metrics
277
+ - Example outputs
278
+
279
+ ### Slide 8: Impact
280
+ - Efficiency gains
281
+ - Cost savings
282
+ - Improved outcomes
283
+
284
+ ### Slide 9: Future Work
285
+ - Potential enhancements
286
+ - Scalability considerations
287
+
288
+ ### Slide 10: Thank You
289
+ - Team members
290
+ - Questions?
291
+
292
+ ---
293
+
294
+ ## 🎬 Demo Script Quick Reference
295
+
296
+ ```
297
+ 1. MEDICAL QUERY
298
+ → "I have a rash on my hands. Is there anything stronger than aquaphor?"
299
+ → Show: Classification, confidence, retrieved documents
300
+
301
+ 2. ADMIN QUERY
302
+ → "Can I get an appointment next month?"
303
+ → Show: Admin classification, no retrieval
304
+
305
+ 3. URGENT QUERY
306
+ → "worst headache of my life with fever and stiff neck"
307
+ → Show: High confidence, relevant results
308
+
309
+ 4. SETTINGS
310
+ → Adjust number of results
311
+ → Toggle reranker
312
+ → Show both view modes
313
+ ```
314
+
315
+ ---
316
+
317
+ ## ✅ Pre-Demo Checklist
318
+
319
+ - [ ] Web UI is running on http://127.0.0.1:7860
320
+ - [ ] All models loaded successfully
321
+ - [ ] Test queries work correctly
322
+ - [ ] Internet connection stable
323
+ - [ ] Screen sharing setup tested
324
+ - [ ] Backup browser tab open
325
+ - [ ] Documentation files ready
326
+ - [ ] Team roles assigned
327
+ - [ ] Timer set for demo sections
328
+ - [ ] Confidence level: HIGH! 🚀
329
+
330
+ ---
331
+
332
+ ## 🎓 Presentation Day Affirmations
333
+
334
+ "We've built something awesome!"
335
+ "Our system works reliably!"
336
+ "We understand every component!"
337
+ "We can explain this clearly!"
338
+ "We're ready for any question!"
339
+
340
+ **Good luck, team! You've got this! 🌟**
341
+
342
+ ---
343
+
344
+ *Prepared by the HealthBot Team*
345
+ *Feel free to customize this script for your specific presentation requirements*
QUICKSTART.md ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 Quick Start Guide - Medical Q&A Bot Web UI
2
+
3
+ This guide will help you get the web interface up and running quickly!
4
+
5
+ ## Prerequisites
6
+
7
+ - Python 3.8 or higher
8
+ - Git (already done since you have the repo)
9
+ - Virtual environment (recommended)
10
+
11
+ ## Step-by-Step Setup
12
+
13
+ ### 1️⃣ Navigate to the Project Directory
14
+
15
+ ```powershell
16
+ cd "c:\Users\Tarak Jha\OneDrive - Coast to Coast Logistics\Desktop\HEALTHBOT\health-query-classifier"
17
+ ```
18
+
19
+ ### 2️⃣ Create and Activate Virtual Environment (Recommended)
20
+
21
+ ```powershell
22
+ # Create virtual environment
23
+ python -m venv .venv
24
+
25
+ # Activate it (Windows PowerShell)
26
+ .venv\Scripts\Activate.ps1
27
+
28
+ # If you get an execution policy error, run this first:
29
+ # Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
30
+ ```
31
+
32
+ ### 3️⃣ Install Dependencies
33
+
34
+ ```powershell
35
+ pip install -r requirements.txt
36
+ ```
37
+
38
+ This will install:
39
+ - All existing dependencies (PyTorch, sentence-transformers, etc.)
40
+ - **Gradio** - For the web UI (recommended)
41
+ - **Streamlit** - Alternative web UI framework
42
+
43
+ ### 4️⃣ Prepare Data (If Not Already Done)
44
+
45
+ ```powershell
46
+ python -m adapters.build_corpora
47
+ ```
48
+
49
+ This creates the necessary corpus files from PubMed and Miriad databases.
50
+
51
+ ### 5️⃣ Launch the Web UI
52
+
53
+ **Option A: Gradio (Recommended)**
54
+ ```powershell
55
+ python app.py
56
+ ```
57
+ Then open: http://127.0.0.1:7860
58
+
59
+ **Option B: Streamlit (Alternative)**
60
+ ```powershell
61
+ streamlit run app_streamlit.py
62
+ ```
63
+ Then open: http://localhost:8501
64
+
65
+ ## 🎯 Choose Your UI
66
+
67
+ ### Gradio (`app.py`)
68
+ ✅ Clean, modern interface
69
+ ✅ Dual-view (Formatted HTML + JSON)
70
+ ✅ Easy to share (can create public links)
71
+ ✅ Automatic API generation
72
+ ✅ Great for ML demos
73
+
74
+ ### Streamlit (`app_streamlit.py`)
75
+ ✅ More interactive and customizable
76
+ ✅ Sidebar with settings
77
+ ✅ Real-time updates
78
+ ✅ Better for data science apps
79
+ ✅ More widgets and components
80
+
81
+ **We recommend starting with Gradio!** It's simpler and looks very professional.
82
+
83
+ ## 🧪 Test with Example Queries
84
+
85
+ Try these queries to see the system in action:
86
+
87
+ 1. **Medical Query:**
88
+ > "I'm having a really bad rash on my hands. I'm pretty sure it's my eczema flaring up. Is there anything stronger than aquaphor I can use on it?"
89
+
90
+ 2. **Medical Emergency:**
91
+ > "worst headache of my life with fever and stiff neck"
92
+
93
+ 3. **Vaccine Question:**
94
+ > "I'm traveling to South America soon. Do I need to get any vaccines before I go?"
95
+
96
+ 4. **Administrative Query:**
97
+ > "Hey is there any way I can get an appointment in the next month?"
98
+
99
+ ## 🎨 UI Features
100
+
101
+ ### Classification
102
+ - Shows whether query is medical, administrative, or other
103
+ - Displays confidence scores for each category
104
+ - Visual progress bars or charts
105
+
106
+ ### Document Retrieval (Medical Queries Only)
107
+ - Retrieves top N relevant documents
108
+ - Shows BM25, Dense, and RRF scores
109
+ - Displays document title, text preview, and metadata
110
+ - Toggle between formatted view and raw JSON
111
+
112
+ ### Settings
113
+ - **Number of Results:** 1-50 documents
114
+ - **Use Reranker:** Enable for better accuracy (slower)
115
+
116
+ ## 🔧 Troubleshooting
117
+
118
+ ### "No module named 'gradio'"
119
+ ```powershell
120
+ pip install gradio
121
+ ```
122
+
123
+ ### "No corpora files found"
124
+ ```powershell
125
+ python -m adapters.build_corpora
126
+ ```
127
+
128
+ ### Port Already in Use
129
+ Edit `app.py` and change the port:
130
+ ```python
131
+ demo.launch(server_port=8080) # Change 7860 to 8080
132
+ ```
133
+
134
+ ### Models Not Loading
135
+ Make sure you have your HuggingFace token configured in `env.list`:
136
+ ```
137
+ HF_TOKEN="your-huggingface-token"
138
+ ```
139
+
140
+ ## 📊 Advanced Options
141
+
142
+ ### Share Publicly (Gradio)
143
+ Edit `app.py`, line ~255:
144
+ ```python
145
+ demo.launch(share=True) # Creates a 72-hour public link
146
+ ```
147
+
148
+ ### Add Authentication (Gradio)
149
+ ```python
150
+ demo.launch(auth=("username", "password"))
151
+ ```
152
+
153
+ ### Change Theme (Streamlit)
154
+ Create `.streamlit/config.toml`:
155
+ ```toml
156
+ [theme]
157
+ primaryColor = "#667eea"
158
+ backgroundColor = "#ffffff"
159
+ secondaryBackgroundColor = "#f0f2f6"
160
+ ```
161
+
162
+ ## 📱 Accessing from Other Devices
163
+
164
+ To access from other devices on your network:
165
+
166
+ 1. Find your IP address:
167
+ ```powershell
168
+ ipconfig
169
+ ```
170
+ Look for "IPv4 Address" (e.g., 192.168.1.100)
171
+
172
+ 2. Edit `app.py`:
173
+ ```python
174
+ demo.launch(server_name="0.0.0.0", server_port=7860)
175
+ ```
176
+
177
+ 3. Access from other device:
178
+ ```
179
+ http://192.168.1.100:7860
180
+ ```
181
+
182
+ ## 🎓 For Your Group Presentation
183
+
184
+ ### Demo Tips:
185
+ 1. Start with the interface loaded beforehand
186
+ 2. Have example queries ready
187
+ 3. Show both medical and administrative classification
188
+ 4. Demonstrate the reranker toggle
189
+ 5. Show both formatted and JSON views
190
+ 6. Explain the confidence scores
191
+
192
+ ### Screenshots to Take:
193
+ - Main interface
194
+ - Classification results
195
+ - Document retrieval results
196
+ - Settings panel
197
+ - Example queries
198
+
199
+ ### Key Points to Mention:
200
+ - Built with modern Python web frameworks
201
+ - Real-time classification and retrieval
202
+ - Hybrid search (BM25 + Dense embeddings)
203
+ - Optional reranking for accuracy
204
+ - Clean, professional interface
205
+
206
+ ## 📝 Next Steps
207
+
208
+ Once you're comfortable with the UI, you can:
209
+ - Customize the styling (CSS in `app.py`)
210
+ - Add more example queries
211
+ - Integrate with other systems via the API
212
+ - Deploy to cloud (Hugging Face Spaces, AWS, etc.)
213
+
214
+ ## 🤝 Team Credits
215
+
216
+ Display proudly on the interface:
217
+ - David Gray
218
+ - Tarak Jha
219
+ - Sravani Segireddy
220
+ - Riley Millikan
221
+ - Kent R. Spillner
222
+
223
+ ---
224
+
225
+ **Need help?** Check the full documentation in `UI_README.md` or ask your team members!
README.md CHANGED
@@ -1,12 +1,81 @@
1
- ---
2
- title: Medical Document Retrieval
3
- emoji: 📈
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.0.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Medical_Document_Retrieval
3
+ app_file: app_retrieval_cached.py
4
+ sdk: gradio
5
+ sdk_version: 6.0.2
6
+ ---
7
+ # Health Query Classifier & Research Retriever
8
+
9
+ ## Team Members
10
+ * **David Gray**
11
+ * **Tarak Jha**
12
+ * **Sravani Segireddy**
13
+ * **Riley Millikan**
14
+ * **Kent R. Spillner**
15
+
16
+ ## Project Description
17
+ This project is a classifier that triages patient queries. If a query is identified as medical, the system retrieves relevant research and presents it to the user.
18
+
19
+ ## Workflow
20
+ The system operates in two main stages to optimize patient care and provider efficiency:
21
+
22
+ 1. **Classification (Triage)**:
23
+ The tool analyzes the user's input to determine if it is a medical query (requiring clinical attention) or an administrative query (scheduling, billing, etc.).
24
+
25
+ 2. **Research Retrieval**:
26
+ If the query is classified as medical, the system searches through indexed medical databases (like PubMed and Miriad) to retrieve relevant research articles and Q/A pairs. This empowers the patient with trustworthy information and provides the doctor with context.
27
+
28
+ ### Training Script
29
+
30
+ ```bash
31
+ python3 -m classifier.train
32
+ ```
33
+
34
+ ## Running the System Locally
35
+
36
+ ### Prerequisites
37
+ * Git
38
+ * Python 3
39
+
40
+ ### Setup & Configuration
41
+
42
+ 1. **Clone the repository**
43
+
44
+ ```bash
45
+ git clone https://github.com/davidgraymi/health-query-classifier.git
46
+ cd health-query-classifier
47
+ ```
48
+
49
+ 2. **Configure environment variables**
50
+
51
+ This project uses an `env.list` file for configuration. Create this file in the root directory.
52
+ ```ini
53
+ # env.list
54
+ HF_TOKEN="your-huggingface-token"
55
+ ```
56
+ * **HF_TOKEN**: Access token can be generated via [huggingface](https://huggingface.co/settings/tokens). The token must have read permissions.
57
+
58
+ 3. **Create a python virtual environment**
59
+
60
+ ```bash
61
+ python3 -m venv .venv
62
+ source .venv/bin/activate
63
+ ```
64
+
65
+ 4. **Install dependencies**
66
+
67
+ ```bash
68
+ pip install -r requirements.txt
69
+ ```
70
+
71
+ ### Data Setup
72
+
73
+ ```bash
74
+ python3 adapters/build_corpora.py
75
+ ```
76
+
77
+ ### Execution
78
+
79
+ ```bash
80
+ python3 main.py
81
+ ```
UI_GUIDE.md ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Summary: Available UI Options for Medical Q&A Bot
2
+
3
+ ## 🎯 Three Versions Created
4
+
5
+ ### 1. **app_demo.py** ⚡ RECOMMENDED FOR DEMOS
6
+ **Port:** 7863
7
+ **Speed:** Instant (<1 second)
8
+ **Features:**
9
+ - ✅ Real-time classification (medical vs administrative)
10
+ - ✅ Confidence scores with visualization
11
+ - ✅ Action recommendations
12
+ - ✅ Uses your group's trained models
13
+ - ❌ No document retrieval (for speed)
14
+
15
+ **Best for:**
16
+ - Class presentations
17
+ - Quick demonstrations
18
+ - Testing classification accuracy
19
+ - When time is limited
20
+
21
+ **Run:** `python app_demo.py`
22
+
23
+ ---
24
+
25
+ ### 2. **app_full.py** 🔬 COMPLETE SYSTEM
26
+ **Port:** 7864
27
+ **Speed:** First query: 6-10 minutes, subsequent: 2-5 seconds
28
+ **Features:**
29
+ - ✅ Real-time classification
30
+ - ✅ **Full document retrieval** from PubMed & Miriad
31
+ - ✅ BM25 + Dense search + RRF fusion
32
+ - ✅ Optional cross-encoder reranking
33
+ - ⚠️ Very slow first initialization
34
+
35
+ **Best for:**
36
+ - Showing full system capabilities
37
+ - When you have 10+ minutes to wait
38
+ - Detailed technical demonstrations
39
+ - Proving retrieval works
40
+
41
+ **Run:** `python app_full.py`
42
+
43
+ **⚠️ WARNING:** First medical query takes 6-10 minutes because:
44
+ - Loads ~200MB+ of medical corpus data
45
+ - Builds BM25 keyword index
46
+ - Generates embeddings for ALL documents (this is the slow part)
47
+ - Builds FAISS vector index
48
+
49
+ ---
50
+
51
+ ### 3. **app.py** / **app_safe.py** / **app_lightweight.py** 🔧 EXPERIMENTAL
52
+ These were intermediate versions created while troubleshooting.
53
+ **Not recommended for use.**
54
+
55
+ ---
56
+
57
+ ## 🎬 Recommendation for Your Group Presentation
58
+
59
+ ### Strategy 1: Fast Demo (5 minutes)
60
+ Use **`app_demo.py`** only:
61
+ 1. Show classification working instantly
62
+ 2. Test medical vs administrative queries
63
+ 3. Highlight confidence scores
64
+ 4. Explain that retrieval is available but disabled for demo speed
65
+ 5. Show the codebase that supports retrieval (team/candidates.py)
66
+
67
+ **Advantage:** Reliable, professional, no waiting
68
+
69
+ ---
70
+
71
+ ### Strategy 2: Split Demo (15+ minutes)
72
+ Use BOTH versions:
73
+
74
+ **Part 1:** Use `app_demo.py` for quick classification demos (5 min)
75
+ - Show multiple queries rapidly
76
+ - Demonstrate accuracy
77
+
78
+ **Part 2:** Switch to `app_full.py` that you pre-initialized (10 min)
79
+ - **Before presentation:** Run `app_full.py` and make ONE medical query to initialize
80
+ - Wait the 10 minutes for it to build indexes
81
+ - Keep it running
82
+ - During presentation: Show actual document retrieval working fast
83
+
84
+ **Advantage:** Shows both speed AND capabilities
85
+
86
+ ---
87
+
88
+ ### Strategy 3: Video Backup
89
+ 1. Use `app_demo.py` for live demo
90
+ 2. Record a video of `app_full.py` working with retrieval
91
+ 3. Show video during presentation if needed
92
+
93
+ ---
94
+
95
+ ## 📊 Technical Details to Mention
96
+
97
+ ### Your Group's Implementation:
98
+ - **Classification Model:** Fine-tuned sentence-transformers (embeddinggemma-300m-medical)
99
+ - **Hybrid Retrieval:** BM25 (sparse) + Dense embeddings (semantic)
100
+ - **Fusion Algorithm:** Reciprocal Rank Fusion (RRF)
101
+ - **Data Sources:** PubMed Medical Q&A + Miriad corpus
102
+ - **Optional Enhancement:** Cross-encoder reranker for accuracy
103
+
104
+ ### Why Retrieval is Slow:
105
+ - Real ML systems need to index large datasets
106
+ - Your corpus has thousands of medical documents
107
+ - CPU-only inference (no GPU acceleration available)
108
+ - This is a REAL implementation, not a toy demo
109
+
110
+ ### Production Solutions:
111
+ - Pre-build and save indexes (don't rebuild each time)
112
+ - Use GPU for faster embedding
113
+ - Implement caching
114
+ - Deploy on cloud with more resources
115
+
116
+ ---
117
+
118
+ ## 💡 Demo Script Suggestions
119
+
120
+ ### Opening (30 seconds):
121
+ "We built an AI system that automatically classifies patient queries and retrieves relevant medical research. Let me show you how it works..."
122
+
123
+ ### Classification Demo (2-3 minutes):
124
+ "First, our classification system determines if a query is medical or administrative..."
125
+ [Use app_demo.py, try 3-4 different queries]
126
+
127
+ ### Technical Explanation (2 minutes):
128
+ "Under the hood, we use:
129
+ - A fine-tuned 300-million parameter transformer model
130
+ - Hybrid search combining keyword matching and semantic similarity
131
+ - Reciprocal Rank Fusion to combine results
132
+ - Medical corpora from PubMed and Miriad databases"
133
+
134
+ ### Show Retrieval (Optional, if pre-initialized):
135
+ "Now let me show you actual document retrieval..."
136
+ [Use app_full.py if you pre-initialized it]
137
+
138
+ ### Closing (30 seconds):
139
+ "This demonstrates how AI can improve healthcare triage, reduce response times, and provide evidence-based information to both patients and providers."
140
+
141
+ ---
142
+
143
+ ## 🚀 Quick Start Commands
144
+
145
+ ```powershell
146
+ # For demos and presentations
147
+ python app_demo.py
148
+ # Access at: http://127.0.0.1:7863
149
+
150
+ # For full system (wait 10 minutes after first query)
151
+ python app_full.py
152
+ # Access at: http://127.0.0.1:7864
153
+ ```
154
+
155
+ ---
156
+
157
+ ## ✅ What You Successfully Built
158
+
159
+ 1. ✅ Working web UI with professional design
160
+ 2. ✅ Real-time classification using your trained model
161
+ 3. ✅ Full retrieval system integrated
162
+ 4. ✅ Two versions: fast demo + complete system
163
+ 5. ✅ Comprehensive documentation
164
+ 6. ✅ Example queries
165
+ 7. ✅ Clear visualization of results
166
+
167
+ **You have everything you need for a successful presentation!**
168
+
169
+ ---
170
+
171
+ ## 🎯 Final Recommendation
172
+
173
+ **For your presentation, use `app_demo.py`**
174
+
175
+ It shows your ML work instantly and professionally. You can explain:
176
+ - "The classification happens in real-time"
177
+ - "The full system includes retrieval which we can show separately"
178
+ - "This demonstrates the core AI capability"
179
+
180
+ If anyone asks about retrieval, you can:
181
+ - Show the code in `team/candidates.py`
182
+ - Explain the hybrid search architecture
183
+ - Mention it's fully implemented but slow due to index building
184
+
185
+ **This is the smart approach for a live demo!**
UI_IMPLEMENTATION.md ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Medical Q&A Bot - UI Implementation Summary
2
+
3
+ ## 📦 What Was Added
4
+
5
+ ### New Files Created:
6
+
7
+ 1. **`app.py`** - Main Gradio web interface (RECOMMENDED)
8
+ - Clean, modern UI with gradient header
9
+ - Dual-view mode (formatted HTML + JSON)
10
+ - Real-time classification and retrieval
11
+ - Example queries built-in
12
+ - Automatic API generation
13
+
14
+ 2. **`app_streamlit.py`** - Alternative Streamlit interface
15
+ - Interactive sidebar with settings
16
+ - Card-based result display
17
+ - Progress bars for confidence scores
18
+ - More customizable styling options
19
+
20
+ 3. **`QUICKSTART.md`** - Step-by-step setup guide
21
+ - Installation instructions
22
+ - How to launch both UIs
23
+ - Troubleshooting tips
24
+ - Demo tips for presentations
25
+
26
+ 4. **`UI_README.md`** - Comprehensive documentation
27
+ - Feature descriptions
28
+ - Configuration options
29
+ - Advanced usage
30
+ - API information
31
+
32
+ 5. **`setup_ui.py`** - Automated setup script
33
+ - Installs all dependencies
34
+ - Checks for required data files
35
+ - Verifies setup completeness
36
+
37
+ ### Modified Files:
38
+
39
+ 1. **`requirements.txt`** - Added:
40
+ - `gradio` - Main UI framework
41
+ - `streamlit` - Alternative UI framework
42
+
43
+ ## 🎯 Key Features
44
+
45
+ ### Classification Display
46
+ - ✅ Shows query type (Medical/Administrative/Other)
47
+ - ✅ Confidence scores with visual indicators
48
+ - ✅ Color-coded results
49
+
50
+ ### Document Retrieval
51
+ - ✅ Retrieves top N relevant documents (1-50)
52
+ - ✅ Shows BM25, Dense, and RRF scores
53
+ - ✅ Displays document preview with full metadata
54
+ - ✅ Optional reranker for better accuracy
55
+
56
+ ### User Experience
57
+ - ✅ Clean, professional design
58
+ - ✅ Example queries for easy testing
59
+ - ✅ Formatted and raw JSON views
60
+ - ✅ Responsive layout
61
+ - ✅ Real-time processing
62
+
63
+ ## 🚀 Quick Start (TL;DR)
64
+
65
+ ```powershell
66
+ # 1. Install dependencies
67
+ pip install -r requirements.txt
68
+
69
+ # 2. Build data (if needed)
70
+ python -m adapters.build_corpora
71
+
72
+ # 3. Launch UI (choose one)
73
+ python app.py # Gradio (recommended)
74
+ streamlit run app_streamlit.py # Streamlit (alternative)
75
+ ```
76
+
77
+ ## 🌟 Which UI to Use?
78
+
79
+ ### Use Gradio (`app.py`) if you want:
80
+ - Quick setup and deployment
81
+ - Clean, minimal interface
82
+ - Easy sharing (public links)
83
+ - Automatic REST API
84
+ - Better for demos and presentations
85
+
86
+ ### Use Streamlit (`app_streamlit.py`) if you want:
87
+ - More interactive controls
88
+ - Sidebar with settings
89
+ - More customization options
90
+ - Data-science focused interface
91
+
92
+ **Recommendation:** Start with **Gradio** - it's simpler and looks very professional!
93
+
94
+ ## 📊 How It Works
95
+
96
+ ```
97
+ User Input → Gradio/Streamlit Interface
98
+
99
+ Classifier (classify query as medical/admin/other)
100
+
101
+ If Medical → Retriever (BM25 + Dense + RRF)
102
+
103
+ Optional Reranker (cross-encoder)
104
+
105
+ Display Results (formatted cards + JSON)
106
+ ```
107
+
108
+ ## 🎓 For Your Presentation
109
+
110
+ ### Demo Flow:
111
+ 1. **Open the interface** - Show the clean design
112
+ 2. **Enter a medical query** - e.g., "I have a rash on my hands"
113
+ 3. **Show classification** - Highlight confidence scores
114
+ 4. **Display results** - Show retrieved medical documents
115
+ 5. **Try administrative query** - Show different handling
116
+ 6. **Toggle settings** - Demonstrate reranker and result count
117
+ 7. **Show JSON view** - For technical audience
118
+
119
+ ### Key Talking Points:
120
+ - "We built a user-friendly web interface using Gradio"
121
+ - "The system classifies queries in real-time with confidence scores"
122
+ - "For medical queries, it retrieves relevant research from PubMed and Miriad"
123
+ - "Uses hybrid search combining BM25 and dense embeddings"
124
+ - "Optional reranker for improved accuracy"
125
+
126
+ ## 🎨 Customization Options
127
+
128
+ ### Change Theme/Colors
129
+ Edit the CSS in `app.py`:
130
+ ```python
131
+ custom_css = """
132
+ .header {
133
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
134
+ }
135
+ """
136
+ ```
137
+
138
+ ### Add More Example Queries
139
+ Edit the `examples` list in `app.py`:
140
+ ```python
141
+ gr.Examples(
142
+ examples=[
143
+ ["Your custom query here"],
144
+ ],
145
+ inputs=query_input,
146
+ )
147
+ ```
148
+
149
+ ### Change Port
150
+ ```python
151
+ demo.launch(server_port=8080) # Default is 7860
152
+ ```
153
+
154
+ ### Enable Public Sharing
155
+ ```python
156
+ demo.launch(share=True) # Creates 72-hour public link
157
+ ```
158
+
159
+ ## 🔧 Technical Details
160
+
161
+ ### Gradio Version
162
+ - Framework: Gradio 4.x
163
+ - Features: Blocks API, custom CSS, examples, tabs
164
+ - Auto-generates REST API at `/docs`
165
+
166
+ ### Streamlit Version
167
+ - Framework: Streamlit
168
+ - Features: Sidebar, caching, progress bars, metrics
169
+ - More suitable for data exploration
170
+
171
+ ### Integration
172
+ - Uses existing classifier and retriever modules
173
+ - No changes to core logic
174
+ - Models loaded once and cached
175
+ - Async processing for better UX
176
+
177
+ ## 📁 File Structure
178
+
179
+ ```
180
+ health-query-classifier/
181
+ ├── app.py # ← Main Gradio UI
182
+ ├── app_streamlit.py # ← Alternative Streamlit UI
183
+ ├── setup_ui.py # ← Setup script
184
+ ├── QUICKSTART.md # ← Quick start guide
185
+ ├── UI_README.md # ← Detailed documentation
186
+ ├── requirements.txt # ← Updated with gradio/streamlit
187
+ ├── classifier/
188
+ │ ├── infer.py # Used by UI
189
+ │ └── ...
190
+ ├── retriever/
191
+ │ └── ...
192
+ └── team/
193
+ ├── candidates.py # Used by UI
194
+ └── ...
195
+ ```
196
+
197
+ ## 🐛 Common Issues & Solutions
198
+
199
+ ### "Module not found: gradio"
200
+ ```powershell
201
+ pip install gradio
202
+ ```
203
+
204
+ ### "No corpora files found"
205
+ ```powershell
206
+ python -m adapters.build_corpora
207
+ ```
208
+
209
+ ### Models take long to load
210
+ - This is normal on first run
211
+ - Models are cached after initial load
212
+ - Consider using smaller models for faster demo
213
+
214
+ ### Port already in use
215
+ - Change port in `app.py` (line ~255)
216
+ - Or kill the process using that port
217
+
218
+ ## 🌐 Deployment Options
219
+
220
+ ### Local (Current Setup)
221
+ - Best for development and demos
222
+ - Access via localhost
223
+
224
+ ### Hugging Face Spaces (Free)
225
+ - Free hosting for Gradio apps
226
+ - Easy to deploy
227
+ - Public URL
228
+
229
+ ### Cloud Platforms
230
+ - AWS, Google Cloud, Azure
231
+ - More control and scalability
232
+ - Requires more setup
233
+
234
+ ## 📈 Future Enhancements
235
+
236
+ Potential additions for future development:
237
+ - User authentication
238
+ - Query history
239
+ - Bookmarking results
240
+ - Export to PDF
241
+ - Multi-language support
242
+ - Voice input
243
+ - Mobile app
244
+ - Analytics dashboard
245
+
246
+ ## 👥 Team Contributions
247
+
248
+ This UI implementation demonstrates:
249
+ - Full-stack development skills
250
+ - ML model integration
251
+ - User experience design
252
+ - Modern web frameworks
253
+ - Professional documentation
254
+
255
+ Perfect for showcasing in your group project presentation!
256
+
257
+ ## 📞 Support
258
+
259
+ For issues or questions:
260
+ 1. Check `QUICKSTART.md` for setup issues
261
+ 2. Check `UI_README.md` for feature documentation
262
+ 3. Review error messages carefully
263
+ 4. Contact team members
264
+
265
+ ---
266
+
267
+ ## ✅ Ready to Present!
268
+
269
+ Your medical Q&A bot now has a professional web interface that:
270
+ - ✅ Looks modern and clean
271
+ - ✅ Is easy to use
272
+ - ✅ Demonstrates your ML capabilities
273
+ - ✅ Provides clear results
274
+ - ✅ Is well-documented
275
+ - ✅ Can be easily deployed
276
+
277
+ **Great work, team!** 🎉
278
+
279
+ ---
280
+
281
+ *Created by: Tarak Jha, with contributions from the entire team*
282
+ *Team: David Gray • Tarak Jha • Sravani Segireddy • Riley Millikan • Kent R. Spillner*
UI_README.md ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Medical Q&A Bot - Web UI
2
+
3
+ This is a user-friendly web interface for the Health Query Classifier & Research Retriever system.
4
+
5
+ ## Features
6
+
7
+ ✨ **Clean, Modern Interface** - Built with Gradio for an intuitive user experience
8
+
9
+ 🎯 **Query Classification** - Automatically triages queries as medical or administrative with confidence scores
10
+
11
+ 📚 **Intelligent Retrieval** - Retrieves relevant medical research from PubMed and Miriad databases
12
+
13
+ 🔍 **Dual View Modes** - View results in formatted HTML or raw JSON
14
+
15
+ ⚙️ **Customizable Settings** - Adjust number of results and toggle reranker for better accuracy
16
+
17
+ ## Quick Start
18
+
19
+ ### 1. Install Dependencies
20
+
21
+ Make sure you have the updated requirements installed:
22
+
23
+ ```bash
24
+ pip install -r requirements.txt
25
+ ```
26
+
27
+ This will install Gradio along with all other dependencies.
28
+
29
+ ### 2. Prepare Data
30
+
31
+ If you haven't already, build the corpora:
32
+
33
+ ```bash
34
+ python -m adapters.build_corpora
35
+ ```
36
+
37
+ ### 3. Launch the Web UI
38
+
39
+ ```bash
40
+ python app.py
41
+ ```
42
+
43
+ The interface will be available at: **http://127.0.0.1:7860**
44
+
45
+ ## Using the Interface
46
+
47
+ 1. **Enter Your Query** - Type your health-related question in the text box
48
+ 2. **Adjust Settings** (optional):
49
+ - **Number of Results**: Control how many documents to retrieve (1-50)
50
+ - **Use Reranker**: Enable for more accurate results (slower)
51
+ 3. **Click "Analyze Query"** or press Enter
52
+ 4. **View Results**:
53
+ - **Classification**: See how your query was categorized
54
+ - **Formatted View**: Readable cards with document information
55
+ - **JSON View**: Raw data for technical analysis
56
+
57
+ ## Example Queries
58
+
59
+ Try these example queries to see the system in action:
60
+
61
+ - "I'm having a really bad rash on my hands. Is there anything stronger than aquaphor I can use?"
62
+ - "I'm traveling to South America soon. Do I need to get any vaccines before I go?"
63
+ - "worst headache of my life with fever and stiff neck"
64
+ - "Hey is there any way I can get an appointment in the next month?"
65
+
66
+ ## Configuration Options
67
+
68
+ ### Sharing Your Interface
69
+
70
+ To create a public shareable link (72 hours), modify `app.py`:
71
+
72
+ ```python
73
+ demo.launch(
74
+ share=True, # Creates a public link
75
+ server_name="127.0.0.1",
76
+ server_port=7860,
77
+ )
78
+ ```
79
+
80
+ ### Custom Port
81
+
82
+ Change the port if 7860 is already in use:
83
+
84
+ ```python
85
+ demo.launch(
86
+ server_port=8080, # Use your preferred port
87
+ )
88
+ ```
89
+
90
+ ### Authentication
91
+
92
+ Add password protection:
93
+
94
+ ```python
95
+ demo.launch(
96
+ auth=("username", "password"), # Simple auth
97
+ )
98
+ ```
99
+
100
+ ## Architecture
101
+
102
+ The UI integrates with your existing codebase:
103
+
104
+ ```
105
+ User Query → Gradio Interface → Classifier → Retriever → Results Display
106
+ ```
107
+
108
+ - **Frontend**: Gradio (Python-based web framework)
109
+ - **Classification**: Uses your trained classifier model
110
+ - **Retrieval**: Hybrid search (BM25 + Dense embeddings + RRF)
111
+ - **Reranking**: Optional cross-encoder reranker
112
+
113
+ ## Troubleshooting
114
+
115
+ ### Models Not Loading
116
+
117
+ Ensure you have the classifier checkpoint and data files:
118
+ ```bash
119
+ ls -la data/corpora/
120
+ ```
121
+
122
+ ### Port Already in Use
123
+
124
+ Change the port number in `app.py` or kill the process using port 7860.
125
+
126
+ ### Gradio Import Error
127
+
128
+ Make sure Gradio is installed:
129
+ ```bash
130
+ pip install gradio
131
+ ```
132
+
133
+ ### No Medical Documents Found
134
+
135
+ Verify that corpora files exist in `data/corpora/`:
136
+ - `medical_qa.jsonl`
137
+ - `miriad_text.jsonl`
138
+ - `unidoc_qa.jsonl`
139
+
140
+ Run the build script if missing:
141
+ ```bash
142
+ python -m adapters.build_corpora
143
+ ```
144
+
145
+ ## Advanced Features
146
+
147
+ ### API Mode
148
+
149
+ Gradio automatically creates a REST API alongside the web UI. Access the API docs at:
150
+ ```
151
+ http://127.0.0.1:7860/docs
152
+ ```
153
+
154
+ ### Embedding the Interface
155
+
156
+ You can embed the Gradio interface in other web applications using iframes:
157
+
158
+ ```html
159
+ <iframe src="http://127.0.0.1:7860" width="100%" height="800px"></iframe>
160
+ ```
161
+
162
+ ## Team
163
+
164
+ - **David Gray**
165
+ - **Tarak Jha**
166
+ - **Sravani Segireddy**
167
+ - **Riley Millikan**
168
+ - **Kent R. Spillner**
169
+
170
+ ## License
171
+
172
+ See the main README.md for project license information.
UI_SUMMARY.md ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🎉 Medical Q&A Bot - UI Implementation Complete!
2
+
3
+ ## ✅ What You Now Have
4
+
5
+ ### 🖥️ Two Complete Web Interfaces
6
+
7
+ #### 1. **Gradio Interface** (`app.py`) - RECOMMENDED ⭐
8
+ - Clean, modern design with gradient styling
9
+ - Dual-view mode (Formatted + JSON)
10
+ - Built-in example queries
11
+ - Easy to share and deploy
12
+ - Automatic REST API generation
13
+ - Launch with: `python app.py`
14
+ - Access at: http://127.0.0.1:7860
15
+
16
+ #### 2. **Streamlit Interface** (`app_streamlit.py`) - ALTERNATIVE
17
+ - Interactive sidebar with live controls
18
+ - Card-based result display
19
+ - Progress bars and metrics
20
+ - More customization options
21
+ - Launch with: `streamlit run app_streamlit.py`
22
+ - Access at: http://localhost:8501
23
+
24
+ ---
25
+
26
+ ## 📚 Complete Documentation Suite
27
+
28
+ ### Quick Start Guides
29
+ 1. **QUICKSTART.md** - Step-by-step setup (5 minutes)
30
+ 2. **launch_ui.bat** - Windows batch launcher (double-click to run)
31
+ 3. **launch_ui.ps1** - PowerShell launcher (right-click → Run with PowerShell)
32
+ 4. **setup_ui.py** - Automated setup script
33
+
34
+ ### Comprehensive Documentation
35
+ 1. **UI_README.md** - Complete UI feature documentation
36
+ 2. **UI_IMPLEMENTATION.md** - Implementation details and summary
37
+ 3. **ARCHITECTURE.md** - System architecture with diagrams
38
+ 4. **PRESENTATION_SCRIPT.md** - Complete presentation guide with demo script
39
+
40
+ ---
41
+
42
+ ## 🚀 How to Get Started (3 Easy Steps)
43
+
44
+ ### Option 1: PowerShell Launcher (Easiest!)
45
+ ```powershell
46
+ # Just double-click or run:
47
+ .\launch_ui.ps1
48
+ ```
49
+
50
+ ### Option 2: Command Line
51
+ ```powershell
52
+ # 1. Install dependencies
53
+ pip install -r requirements.txt
54
+
55
+ # 2. Build data (if needed)
56
+ python -m adapters.build_corpora
57
+
58
+ # 3. Launch!
59
+ python app.py
60
+ ```
61
+
62
+ ### Option 3: Setup Script
63
+ ```powershell
64
+ # Run the automated setup
65
+ python setup_ui.py
66
+
67
+ # Then launch
68
+ python app.py
69
+ ```
70
+
71
+ ---
72
+
73
+ ## 🎨 Key Features
74
+
75
+ ### Classification
76
+ ✅ Automatic query classification (Medical/Administrative/Other)
77
+ ✅ Confidence scores for transparency
78
+ ✅ Visual indicators and progress bars
79
+ ✅ Color-coded results
80
+
81
+ ### Retrieval
82
+ ✅ Hybrid search (BM25 + Dense + RRF)
83
+ ✅ Retrieves from PubMed, Miriad, and UniDoc
84
+ ✅ Adjustable number of results (1-50)
85
+ ✅ Optional cross-encoder reranking
86
+ ✅ Multiple relevance scores per document
87
+
88
+ ### User Experience
89
+ ✅ Clean, professional interface
90
+ ✅ Example queries built-in
91
+ ✅ Real-time processing
92
+ ✅ Formatted and JSON view modes
93
+ ✅ Mobile-responsive design
94
+ ✅ Error handling and validation
95
+
96
+ ---
97
+
98
+ ## 📁 New Files Created
99
+
100
+ ```
101
+ health-query-classifier/
102
+ ├── 🌐 Web Interfaces
103
+ │ ├── app.py ⭐ (Main Gradio UI)
104
+ │ ├── app_streamlit.py (Alternative Streamlit UI)
105
+ │ ├── launch_ui.bat (Windows launcher)
106
+ │ └── launch_ui.ps1 (PowerShell launcher)
107
+
108
+ ├── 📚 Documentation
109
+ │ ├── QUICKSTART.md (5-minute setup guide)
110
+ │ ├── UI_README.md (Feature documentation)
111
+ │ ├── UI_IMPLEMENTATION.md (Technical summary)
112
+ │ ├── ARCHITECTURE.md (System diagrams)
113
+ │ ├── PRESENTATION_SCRIPT.md (Demo script)
114
+ │ └── UI_SUMMARY.md (This file)
115
+
116
+ ├── 🔧 Setup Tools
117
+ │ └── setup_ui.py (Automated installer)
118
+
119
+ └── 📦 Updated Files
120
+ └── requirements.txt (Added gradio + streamlit)
121
+ ```
122
+
123
+ ---
124
+
125
+ ## 🎯 What Each Interface Looks Like
126
+
127
+ ### Gradio Interface Features:
128
+ ```
129
+ ┌─────────────────────────────────────────┐
130
+ │ 🏥 Medical Q&A Bot │
131
+ │ Health Query Classifier & Retriever │
132
+ │ Team: David • Tarak • Sravani • etc. │
133
+ ├─────────────────────────────────────────┤
134
+ │ │
135
+ │ [Enter your health query...] │
136
+ │ ┌─────────────────────────────────┐ │
137
+ │ │ Number of Results: [10] │ │
138
+ │ │ ☐ Use Reranker │ │
139
+ │ └─────────────────────────────────┘ │
140
+ │ │
141
+ │ [🔍 Analyze Query] │
142
+ │ │
143
+ ├─────────────────────────────────────────┤
144
+ │ Classification Result: │
145
+ │ ✓ MEDICAL (95% confidence) │
146
+ │ - Medical: 95.2% ████████████ │
147
+ │ - Administrative: 4.8% █ │
148
+ ├─────────────────────────────────────────┤
149
+ │ [📄 Formatted View] [📊 JSON View] │
150
+ │ │
151
+ │ Found 10 Relevant Documents │
152
+ │ ┌───────────────────────────────┐ │
153
+ │ │ Result #1: Eczema Treatment │ │
154
+ │ │ BM25: 0.85 Dense: 0.92 RRF: 1.2│ │
155
+ │ │ Text: Treatment options for...│ │
156
+ │ └───────────────────────────────┘ │
157
+ │ ... │
158
+ └─────────────────────────────────────────┘
159
+ ```
160
+
161
+ ### Streamlit Interface Features:
162
+ ```
163
+ ┌─────────────┬───────────────────────────┐
164
+ │ ⚙️ Settings │ 🏥 Medical Q&A Bot │
165
+ │ │ ═══════════════════════ │
166
+ │ Results: 10 │ │
167
+ │ ▁▁▁▁▁▁▁▁ │ [Query input box...] │
168
+ │ │ │
169
+ │ ☐ Reranker │ [🔍 Analyze Query] │
170
+ │ │ │
171
+ │ ☐ JSON View │ Classification: │
172
+ │ │ 🏥 MEDICAL │
173
+ │ Examples: │ │
174
+ │ • Rash... │ Confidence: │
175
+ │ • Vaccine...│ ████████████ 95% │
176
+ │ • Headache..│ │
177
+ │ │ Results: │
178
+ │ │ ┌─────────────────┐ │
179
+ │ │ │ Result #1 │ │
180
+ │ │ │ BM25 │ Dense │ │
181
+ │ │ │ 0.85 │ 0.92 │ │
182
+ │ │ └─────────────────┘ │
183
+ └─────────────┴───────────────────────────┘
184
+ ```
185
+
186
+ ---
187
+
188
+ ## 💡 Demo Workflow
189
+
190
+ ### Perfect Demo Sequence:
191
+ 1. **Start** → Launch UI (`python app.py`)
192
+ 2. **Medical Query** → "I have a rash on my hands..."
193
+ - Show classification
194
+ - Show retrieved documents
195
+ - Point out scores
196
+ 3. **Admin Query** → "Can I get an appointment?"
197
+ - Show different classification
198
+ - No retrieval happens
199
+ 4. **Settings** → Adjust results, toggle reranker
200
+ 5. **Views** → Switch between formatted and JSON
201
+
202
+ ---
203
+
204
+ ## 🎓 For Your Presentation
205
+
206
+ ### Talking Points:
207
+ ✅ "We built a professional web interface using Gradio"
208
+ ✅ "The system classifies queries in real-time"
209
+ ✅ "For medical queries, it retrieves relevant research"
210
+ ✅ "Uses hybrid search with BM25 and dense embeddings"
211
+ ✅ "Optional reranking for improved accuracy"
212
+ ✅ "Clean, intuitive user experience"
213
+
214
+ ### What to Demo:
215
+ ✅ Classification confidence scores
216
+ ✅ Document retrieval results
217
+ ✅ Different query types (medical vs admin)
218
+ ✅ Settings adjustment (reranker, result count)
219
+ ✅ Multiple view modes (formatted + JSON)
220
+
221
+ ### Impressive Technical Details:
222
+ ✅ Sentence transformer embeddings
223
+ ✅ Neural network classifier
224
+ ✅ FAISS vector search
225
+ ✅ RRF fusion algorithm
226
+ ✅ Cross-encoder reranking
227
+ ✅ Professional UI framework
228
+
229
+ ---
230
+
231
+ ## 🛠️ Troubleshooting
232
+
233
+ ### Common Issues:
234
+
235
+ **"Module not found: gradio"**
236
+ ```powershell
237
+ pip install gradio
238
+ ```
239
+
240
+ **"No corpora files found"**
241
+ ```powershell
242
+ python -m adapters.build_corpora
243
+ ```
244
+
245
+ **"Port already in use"**
246
+ ```python
247
+ # Edit app.py line ~255
248
+ demo.launch(server_port=8080) # Change port
249
+ ```
250
+
251
+ **"Models loading slowly"**
252
+ - This is normal on first run
253
+ - Models are cached afterward
254
+ - Takes 30-60 seconds initially
255
+
256
+ ---
257
+
258
+ ## 🌟 Why This Implementation is Great
259
+
260
+ ### For Your Project:
261
+ ✅ Professional appearance
262
+ ✅ Easy to demonstrate
263
+ ✅ Well-documented
264
+ ✅ Production-ready foundation
265
+ ✅ Impressive to stakeholders
266
+
267
+ ### For Your Resume:
268
+ ✅ Modern tech stack (Gradio, PyTorch, Transformers)
269
+ ✅ Full-stack development (UI + ML backend)
270
+ ✅ Healthcare application (impactful domain)
271
+ ✅ Clean, maintainable code
272
+ ✅ Comprehensive documentation
273
+
274
+ ### For Future Development:
275
+ ✅ Easy to extend
276
+ ✅ Modular architecture
277
+ ✅ Multiple deployment options
278
+ ✅ API already available
279
+ ✅ Scalable design
280
+
281
+ ---
282
+
283
+ ## 📊 File Sizes
284
+
285
+ ```
286
+ app.py ~9 KB (Main UI)
287
+ app_streamlit.py ~7 KB (Alt UI)
288
+ QUICKSTART.md ~5 KB (Setup guide)
289
+ UI_README.md ~8 KB (Features)
290
+ UI_IMPLEMENTATION.md ~10 KB (Details)
291
+ ARCHITECTURE.md ~15 KB (Diagrams)
292
+ PRESENTATION_SCRIPT.md ~12 KB (Demo guide)
293
+ ```
294
+
295
+ Total new documentation: **~66 KB of helpful guides!**
296
+
297
+ ---
298
+
299
+ ## 🎯 Next Steps
300
+
301
+ ### Immediate (5 minutes):
302
+ 1. Run `pip install gradio streamlit`
303
+ 2. Launch UI with `python app.py`
304
+ 3. Test with example queries
305
+ 4. Familiarize yourself with features
306
+
307
+ ### Short Term (1 hour):
308
+ 1. Read through QUICKSTART.md
309
+ 2. Test both Gradio and Streamlit interfaces
310
+ 3. Prepare demo queries
311
+ 4. Practice presentation flow
312
+
313
+ ### Before Presentation:
314
+ 1. Review PRESENTATION_SCRIPT.md
315
+ 2. Test demo multiple times
316
+ 3. Prepare backup slides
317
+ 4. Assign team roles
318
+ 5. Get excited! 🚀
319
+
320
+ ---
321
+
322
+ ## 🤝 Team Credits
323
+
324
+ **Built by:**
325
+ - David Gray
326
+ - Tarak Jha
327
+ - Sravani Segireddy
328
+ - Riley Millikan
329
+ - Kent R. Spillner
330
+
331
+ **Technologies Used:**
332
+ - Python 3.8+
333
+ - PyTorch
334
+ - Sentence-Transformers
335
+ - Gradio
336
+ - Streamlit
337
+ - FAISS
338
+ - BM25
339
+ - scikit-learn
340
+
341
+ ---
342
+
343
+ ## 🎉 You're All Set!
344
+
345
+ Your medical Q&A bot now has:
346
+ ✅ Two professional web interfaces
347
+ ✅ Complete documentation
348
+ ✅ Easy launchers
349
+ ✅ Presentation guide
350
+ ✅ Demo script
351
+ ✅ Architecture diagrams
352
+
353
+ **Everything you need for a successful demo and presentation!**
354
+
355
+ ---
356
+
357
+ ## 🚀 Quick Commands Reference
358
+
359
+ ```powershell
360
+ # Install everything
361
+ pip install -r requirements.txt
362
+
363
+ # Build data
364
+ python -m adapters.build_corpora
365
+
366
+ # Launch Gradio UI (recommended)
367
+ python app.py
368
+
369
+ # Launch Streamlit UI (alternative)
370
+ streamlit run app_streamlit.py
371
+
372
+ # Run automated setup
373
+ python setup_ui.py
374
+
375
+ # Use launcher scripts
376
+ .\launch_ui.ps1
377
+ ```
378
+
379
+ ---
380
+
381
+ ## 📞 Need Help?
382
+
383
+ 1. Check QUICKSTART.md for setup issues
384
+ 2. Check UI_README.md for feature questions
385
+ 3. Check ARCHITECTURE.md for technical details
386
+ 4. Check PRESENTATION_SCRIPT.md for demo help
387
+ 5. Ask your team members!
388
+
389
+ ---
390
+
391
+ ## ✨ Final Notes
392
+
393
+ This implementation provides:
394
+ - **Professional quality** - Ready to show to professors, potential employers
395
+ - **Well-documented** - Easy for team members to understand
396
+ - **Extensible** - Can be built upon for future projects
397
+ - **Portfolio-worthy** - Great addition to your GitHub
398
+
399
+ **You've got an impressive project here. Go show it off! 🌟**
400
+
401
+ ---
402
+
403
+ *Created: December 3, 2025*
404
+ *For: Health Query Classifier Group Project*
405
+ *By: Your friendly AI assistant*
__init__.py ADDED
File without changes
adapters/build_corpora.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, jsonlines, pathlib
2
+ import concurrent.futures
3
+ from tqdm import tqdm
4
+ from datasets import load_dataset
5
+ from math import ceil
6
+ from pubmed import download_pubmed
7
+
8
+ OUT = pathlib.Path("data/corpora")
9
+ OUT.mkdir(parents=True, exist_ok=True)
10
+
11
+ PUBMED_ARTICLES_PER_XML_FILE = 30000
12
+
13
+ def write_jsonl(path, rows):
14
+ print(f"Writing {len(rows)} records to {path}")
15
+ with jsonlines.open(path, "w") as out:
16
+ out.write_all(rows)
17
+ print(f"Finished writing {path}")
18
+
19
+ # 1) LasseRegin medical Q&A
20
+ def build_lasseregin():
21
+ print("Starting LasseRegin build...")
22
+ import urllib.request
23
+ url = "https://raw.githubusercontent.com/LasseRegin/medical-question-answer-data/master/icliniqQAs.json"
24
+
25
+ try:
26
+ with urllib.request.urlopen(url) as response:
27
+ data = json.loads(response.read().decode("utf-8"))
28
+ except Exception as e:
29
+ print(f"Failed to download LasseRegin data: {e}")
30
+ return
31
+
32
+ rows = []
33
+ for i, r in enumerate(tqdm(data, desc="LasseRegin", leave=False)):
34
+ rows.append({
35
+ "id": f"icliniq:{i}",
36
+ "title": r.get("title",""),
37
+ "question": r.get("question",""),
38
+ "answer": r.get("answer",""),
39
+ "source": "icliniq"
40
+ })
41
+ write_jsonl(OUT / "medical_qa.jsonl", rows)
42
+ print("Completed LasseRegin build.")
43
+
44
+ # 2) MIRIAD-4.4M-split
45
+ def build_miriad(sample_size=200_000):
46
+ print(f"Starting MIRIAD build (sample_size={sample_size})...")
47
+ try:
48
+ ds = load_dataset("miriad/miriad-4.4M", num_proc=4, split="train")
49
+
50
+ ds = ds.shuffle(seed=42).select(range(min(sample_size, len(ds))))
51
+ except Exception as e:
52
+ print(f"Failed to load MIRIAD dataset: {e}")
53
+ return
54
+
55
+ rows = []
56
+ for i, ex in enumerate(tqdm(ds, desc="miriad", leave=False)):
57
+ rows.append({
58
+ "id": f"miriad:{i}",
59
+ "title": ex.get("paper_title",""),
60
+ "question": ex.get("question", ""),
61
+ "answer": ex.get("passage_text", ""),
62
+ "year": ex.get("year",""),
63
+ "specialty": ex.get("specialty",""),
64
+
65
+ })
66
+ write_jsonl(OUT / "miriad_text.jsonl", rows)
67
+ print("Completed MIRIAD build.")
68
+
69
+ # 3) PubMed abstracts
70
+ def build_pubmed(max_records=500_000):
71
+ num_files = int(ceil(max_records / PUBMED_ARTICLES_PER_XML_FILE))
72
+ print(f"Starting PubMed build (num_files={num_files}, max_records={max_records})...")
73
+
74
+ download_pubmed(OUT / "pubmed.jsonl", num_files)
75
+ print("Completed PubMed build.")
76
+
77
+ # 4) UniDoc-Bench (QA)
78
+ def build_unidoc(max_items=1000):
79
+ print(f"Starting UniDoc build (max_items={max_items})...")
80
+ try:
81
+ ds = load_dataset("Salesforce/UniDoc-Bench", split="healthcare")
82
+ except Exception as e:
83
+ print(f"Failed to load UniDoc dataset: {e}")
84
+ return
85
+
86
+ rows = []
87
+ for i, ex in enumerate(tqdm(ds, desc="unidoc", leave=False)):
88
+ q = ex.get("question","") or ex.get("query","")
89
+ a = ex.get("answer","") or ""
90
+ pdf = ex.get("pdf_path") or ex.get("document_path") or ""
91
+ domain = ex.get("domain","")
92
+ rows.append({
93
+ "id": f"unidoc:{i}",
94
+ "title": f"{domain} PDF",
95
+ "question": q,
96
+ "answer": a,
97
+ "pdf_path": pdf
98
+ })
99
+ if i+1 >= max_items:
100
+ break
101
+ write_jsonl(OUT / "unidoc_qa.jsonl", rows)
102
+ print("Completed UniDoc build.")
103
+
104
+ def main():
105
+ print("Starting parallel corpora build...")
106
+ # Define tasks
107
+ tasks = [
108
+ (build_lasseregin, []),
109
+ (build_miriad, [1000]),
110
+ (build_pubmed, [500_000]),
111
+
112
+ (build_unidoc, [1000])
113
+ ]
114
+
115
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
116
+ futures = [executor.submit(func, *args) for func, args in tasks]
117
+ for future in concurrent.futures.as_completed(futures):
118
+ try:
119
+ future.result()
120
+ except Exception as e:
121
+ print(f"A task failed: {e}")
122
+
123
+ print("✅ All corpora built successfully in data/corpora/")
124
+
125
+ if __name__ == "__main__":
126
+ main()
adapters/pubmed.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+
3
+ import os
4
+ import gzip
5
+ import json
6
+ import re
7
+ import subprocess
8
+ import xml.etree.ElementTree as ET
9
+
10
+ from tqdm import tqdm
11
+ from urllib.request import urlopen, urlretrieve
12
+
13
+
14
+ PUBMED_DATASET_BASE_URL = "https://ftp.ncbi.nlm.nih.gov/pubmed/baseline"
15
+
16
+ PUBMED_FILE_LIMIT = 10
17
+
18
+
19
+ def get_pubmed_dataset_size():
20
+ try:
21
+ with urlopen(PUBMED_DATASET_BASE_URL) as response:
22
+ html = response.read().decode("utf-8")
23
+
24
+ files = re.findall(r"(pubmed\d+n\d+)\.xml\.gz(?!\.)", html)
25
+ unique_files = set(files)
26
+
27
+ return len(unique_files)
28
+
29
+ except Exception as e:
30
+ print(f"Unable to count PubMed files: {e}")
31
+
32
+ return 0
33
+
34
+
35
+ def download_pubmed_xml(output_dir, num_files=1, year='25'):
36
+ os.makedirs(output_dir, exist_ok=True)
37
+
38
+ total_dataset_size = get_pubmed_dataset_size()
39
+
40
+ files = []
41
+ pbar = tqdm(total=total_dataset_size, desc=f"Downloading {num_files}/{total_dataset_size} files in PubMed dataset")
42
+
43
+ for i in range(1, num_files + 1):
44
+ filename = f"pubmed{year}n{i:04d}.xml.gz"
45
+ filepath = os.path.join(output_dir, filename)
46
+
47
+ if not os.path.exists(filepath):
48
+ urlretrieve(f"{PUBMED_DATASET_BASE_URL}/{filename}", filepath)
49
+
50
+ pbar.update(1)
51
+
52
+ files.append(filepath)
53
+
54
+ pbar.close()
55
+
56
+ return files
57
+
58
+
59
+ def parse_pubmed_to_jsonl(xml_files, output_jsonl):
60
+ with open(output_jsonl, 'w') as out:
61
+ for xml_file in xml_files:
62
+ print(f"Parsing {xml_file}...")
63
+ with gzip.open(xml_file, 'rt', encoding='utf-8') as f:
64
+ tree = ET.parse(f)
65
+ root = tree.getroot()
66
+
67
+ for article in tqdm(root.findall('.//PubmedArticle')):
68
+ pmid_elem = article.find('.//PMID')
69
+ title_elem = article.find('.//ArticleTitle')
70
+ abstract_elem = article.find('.//Abstract/AbstractText')
71
+
72
+ if pmid_elem is not None:
73
+ title = title_elem.text if title_elem is not None else ""
74
+ abstract = abstract_elem.text if abstract_elem is not None else ""
75
+
76
+ doc = {
77
+ 'id': pmid_elem.text,
78
+ 'title': title,
79
+ 'contents': f"{title} {abstract}".strip()
80
+ }
81
+ out.write(json.dumps(doc) + '\n')
82
+
83
+
84
+ def download_pubmed(output_jsonl, num_files=1):
85
+ if os.path.exists(output_jsonl):
86
+ print(f"Already downloaded PubMed dataset: {output_jsonl}")
87
+
88
+ return
89
+
90
+ xml_dir = os.path.join(os.path.dirname(output_jsonl), '../pubmed-xml')
91
+ xml_files = download_pubmed_xml(xml_dir, num_files=num_files)
92
+ parse_pubmed_to_jsonl(xml_files, output_jsonl)
93
+
94
+
95
+ def build_index_cmd(input_file, index_dir):
96
+ return [
97
+ "python", "-m", "pyserini.index.lucene",
98
+ "--collection", "JsonCollection",
99
+ "--input", os.path.dirname(input_file),
100
+ "--index", index_dir,
101
+ "--generator", "DefaultLuceneDocumentGenerator",
102
+ "--threads", "32",
103
+ "--storePositions",
104
+ "--storeDocvectors",
105
+ "--storeRaw",
106
+ ]
107
+
108
+
109
+ def build_index(input_file, index_dir, cmd_generator=build_index_cmd):
110
+ if os.path.exists(index_dir) and os.listdir(index_dir):
111
+ print(f"Skipping existing index: {index_dir}")
112
+
113
+ return
114
+
115
+ os.makedirs(os.path.dirname(index_dir) or '.', exist_ok=True)
116
+
117
+ cmd = cmd_generator(input_file, index_dir)
118
+
119
+ subprocess.run(cmd, check=True)
120
+
121
+
122
+ def main(base_data_dir="data", base_index_dir="indexes", num_files=1):
123
+ corpus_jsonl = os.path.join(base_data_dir, "pubmed", "corpus.jsonl")
124
+ index_dir = os.path.join(base_index_dir, "pubmed")
125
+
126
+ download_pubmed(corpus_jsonl, num_files=num_files)
127
+
128
+ build_index(corpus_jsonl, index_dir)
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main(num_files=PUBMED_FILE_LIMIT)
app_retrieval_cached.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Medical Q&A UI - BM25 + Dense Retrieval Models WITH DISK CACHING
3
+ This version caches the indexes to disk for fast startup (30 seconds vs 5-8 minutes!)
4
+ """
5
+
6
+ import gradio as gr
7
+ from typing import Dict, List
8
+ from pathlib import Path
9
+ import pickle
10
+ import hashlib
11
+ import json
12
+ from retriever.index_bm25 import BM25Index
13
+ from retriever.index_dense import DenseIndex
14
+ from retriever.ingest import load_jsonl
15
+ from retriever.rrf import rrf
16
+ from team.interfaces import Candidate
17
+
18
+ # Cache directory
19
+ CACHE_DIR = Path("cache")
20
+ CACHE_DIR.mkdir(exist_ok=True)
21
+
22
+ print("=" * 70)
23
+ print(" Medical Document Retrieval System (CACHED VERSION)")
24
+ print(" Using BM25 + Dense Embeddings + RRF Fusion")
25
+ print(" With disk caching for fast startup!")
26
+ print("=" * 70)
27
+
28
+
29
+ def _default_corpora_config() -> Dict[str, dict]:
30
+ return {
31
+ "medical_qa": {"path": "data/corpora/medical_qa.jsonl",
32
+ "text_fields": ["question", "answer", "title"]},
33
+ "miriad": {"path": "data/corpora/miriad_text.jsonl",
34
+ "text_fields": ["question", "answer", "title"]},
35
+ "unidoc": {"path": "data/corpora/unidoc_qa.jsonl",
36
+ "text_fields": ["question", "answer", "title"]},
37
+ }
38
+
39
+
40
+ def _available(cfg: Dict[str, dict]) -> Dict[str, dict]:
41
+ return {k: v for k, v in cfg.items() if Path(v["path"]).exists()}
42
+
43
+
44
+ def _get_cache_key(corpora_config: Dict[str, dict]) -> str:
45
+ """Generate a unique cache key based on corpora config"""
46
+ config_str = json.dumps(corpora_config, sort_keys=True)
47
+ return hashlib.md5(config_str.encode()).hexdigest()
48
+
49
+
50
+ class CachedRetriever:
51
+ """Retriever with disk caching for BM25 and Dense indexes"""
52
+
53
+ def __init__(self, corpora_config: Dict[str, dict], use_reranker: bool = False):
54
+ self.corpora_config = corpora_config
55
+ self.use_reranker = use_reranker
56
+ self.cache_key = _get_cache_key(corpora_config)
57
+
58
+ # Cache file paths
59
+ self.bm25_cache = CACHE_DIR / f"bm25_{self.cache_key}.pkl"
60
+ self.dense_cache = CACHE_DIR / f"dense_{self.cache_key}.pkl"
61
+ self.docs_cache = CACHE_DIR / f"docs_{self.cache_key}.pkl"
62
+
63
+ # Load or build indexes
64
+ self.docs_all = self._load_or_build_docs()
65
+ self.bm25 = self._load_or_build_bm25()
66
+ self.dense = self._load_or_build_dense()
67
+
68
+ def _load_or_build_docs(self) -> List:
69
+ """Load documents from cache or build from scratch"""
70
+ if self.docs_cache.exists():
71
+ print(f"Loading documents from cache... ({self.docs_cache.name})")
72
+ try:
73
+ with open(self.docs_cache, 'rb') as f:
74
+ docs_all = pickle.load(f)
75
+ print(f" ✓ Loaded {len(docs_all)} documents from cache")
76
+ return docs_all
77
+ except Exception as e:
78
+ print(f" ✗ Cache load failed: {e}")
79
+ print(" → Rebuilding documents...")
80
+
81
+ print("Building documents from corpora files...")
82
+ docs_all = []
83
+ for name, cfg in self.corpora_config.items():
84
+ print(f" Loading {name}...")
85
+ docs = load_jsonl(cfg["path"], tuple(cfg.get("text_fields", ("question", "answer"))))
86
+ docs_all.extend(docs)
87
+
88
+ # Save to cache
89
+ print(f"Saving documents to cache... ({len(docs_all)} docs)")
90
+ with open(self.docs_cache, 'wb') as f:
91
+ pickle.dump(docs_all, f)
92
+
93
+ return docs_all
94
+
95
+ def _load_or_build_bm25(self) -> BM25Index:
96
+ """Load BM25 index from cache or build from scratch"""
97
+ if self.bm25_cache.exists():
98
+ print(f"Loading BM25 index from cache... ({self.bm25_cache.name})")
99
+ try:
100
+ with open(self.bm25_cache, 'rb') as f:
101
+ bm25_index = pickle.load(f)
102
+ print(f" ✓ BM25 index loaded from cache")
103
+ return bm25_index
104
+ except Exception as e:
105
+ print(f" ✗ Cache load failed: {e}")
106
+ print(" → Rebuilding BM25 index...")
107
+
108
+ print("Building BM25 index from scratch...")
109
+ bm25_index = BM25Index(self.docs_all)
110
+
111
+ # Save to cache
112
+ print(f"Saving BM25 index to cache...")
113
+ with open(self.bm25_cache, 'wb') as f:
114
+ pickle.dump(bm25_index, f)
115
+
116
+ return bm25_index
117
+
118
+ def _load_or_build_dense(self) -> DenseIndex:
119
+ """Load Dense index from cache or build from scratch"""
120
+ if self.dense_cache.exists():
121
+ print(f"Loading Dense index from cache... ({self.dense_cache.name})")
122
+ try:
123
+ with open(self.dense_cache, 'rb') as f:
124
+ dense_index = pickle.load(f)
125
+ print(f" ✓ Dense index loaded from cache")
126
+ return dense_index
127
+ except Exception as e:
128
+ print(f" ✗ Cache load failed: {e}")
129
+ print(" → Rebuilding Dense index...")
130
+
131
+ print("Building Dense index from scratch (this takes 5-8 minutes)...")
132
+ dense_index = DenseIndex(self.docs_all)
133
+
134
+ # Save to cache
135
+ print(f"Saving Dense index to cache...")
136
+ with open(self.dense_cache, 'wb') as f:
137
+ pickle.dump(dense_index, f)
138
+
139
+ return dense_index
140
+
141
+
142
+ # Initialize cached retriever (fast if cached, slow first time)
143
+ print("\nInitializing retrieval system...")
144
+ cfg = _available(_default_corpora_config())
145
+ if not cfg:
146
+ raise RuntimeError("No corpora files found in data/corpora. Build them first.")
147
+
148
+ retriever = CachedRetriever(corpora_config=cfg, use_reranker=False)
149
+
150
+ print("\n✓ Retrieval system ready!")
151
+ print(f" Total documents indexed: {len(retriever.docs_all):,}")
152
+ print("=" * 70)
153
+
154
+
155
+ def get_candidates_cached(query: str, k_retrieve: int = 50) -> List[Candidate]:
156
+ """
157
+ Returns top-N fused candidates with component scores (bm25, dense, rrf).
158
+ Uses the cached retriever for fast queries.
159
+ """
160
+ # Get separate result lists (doc, score)
161
+ bm = retriever.bm25.search(query, k=max(k_retrieve, 100))
162
+ de = retriever.dense.search(query, k=max(k_retrieve, 100))
163
+
164
+ # Maps for score lookup
165
+ bm_map = {d.id: float(s) for d, s in bm}
166
+ de_map = {d.id: float(s) for d, s in de}
167
+
168
+ # Fuse and pick candidate set
169
+ fused = rrf([bm, de], k=max(k_retrieve, 50))
170
+
171
+ # Compute RRF per candidate using rank positions
172
+ K = 60
173
+ bm_rank = {d.id: i for i, (d, _) in enumerate(bm)}
174
+ de_rank = {d.id: i for i, (d, _) in enumerate(de)}
175
+
176
+ out: List[Candidate] = []
177
+ for doc, _ in fused[:k_retrieve]:
178
+ rrf_score = 0.0
179
+ if doc.id in bm_rank:
180
+ rrf_score += 1.0 / (K + bm_rank[doc.id] + 1)
181
+ if doc.id in de_rank:
182
+ rrf_score += 1.0 / (K + de_rank[doc.id] + 1)
183
+ out.append(Candidate(
184
+ id=doc.id,
185
+ title=doc.title or "",
186
+ text=doc.text,
187
+ meta=doc.meta or {},
188
+ bm25=bm_map.get(doc.id, 0.0),
189
+ dense=de_map.get(doc.id, 0.0),
190
+ rrf=rrf_score,
191
+ ))
192
+ # Baseline order: RRF
193
+ out.sort(key=lambda c: c.rrf, reverse=True)
194
+ return out
195
+
196
+
197
+ def retrieve_documents(query, num_results=5):
198
+ """Retrieve relevant medical documents using your team's models"""
199
+ if not query or not query.strip():
200
+ return """
201
+ <div style="padding: 20px; background-color: #e7f3ff; border-radius: 10px; border-left: 5px solid #2196f3;">
202
+ <h3 style="margin-top: 0; color: #0d47a1;">How to Use</h3>
203
+ <p style="margin: 0; color: #1565c0;">Enter a medical query and we'll find relevant documents using BM25 + Dense retrieval with RRF fusion.</p>
204
+ <p style="margin: 8px 0 0 0; color: #1565c0;"><strong>Example:</strong> "headache with blurred vision" or "symptoms of diabetes"</p>
205
+ </div>
206
+ """
207
+
208
+ try:
209
+ # Use cached retrieval system (fast!)
210
+ hits = get_candidates_cached(query=query, k_retrieve=num_results)
211
+
212
+ if not hits:
213
+ return """
214
+ <div style="padding: 20px; background-color: #fff3cd; border-radius: 10px; border-left: 5px solid #ffc107;">
215
+ <h3 style="margin-top: 0; color: #856404;">No Results Found</h3>
216
+ <p style="margin: 0; color: #856404;">Try rephrasing your query or using different medical terms.</p>
217
+ </div>
218
+ """
219
+
220
+ # Build results HTML
221
+ result_html = f"""
222
+ <div style="padding: 15px; background: linear-gradient(135deg, #d4edda 0%, #c3e6cb 100%); border-radius: 10px; margin-bottom: 20px; border-left: 5px solid #28a745;">
223
+ <h3 style="margin-top: 0; color: #155724;">Found {len(hits)} Relevant Medical Documents</h3>
224
+ <p style="margin: 0;"><strong>Retrieved using:</strong> BM25 + Dense Embeddings + RRF Fusion (CACHED)</p>
225
+ </div>
226
+ """
227
+
228
+ for i, hit in enumerate(hits, 1):
229
+ title = hit.title if hit.title and hit.title.strip() else None
230
+ source = hit.meta.get('source', 'Unknown') if hit.meta else 'Unknown'
231
+
232
+ # Check if we have separate question/answer fields in metadata
233
+ question = hit.meta.get('question', '') if hit.meta else ''
234
+ answer = hit.meta.get('answer', '') if hit.meta else ''
235
+
236
+ # If we have separate Q&A, format them nicely
237
+ if question and answer:
238
+ content_html = f"""
239
+ <div style="margin-bottom: 12px;">
240
+ <strong style="color: #1976d2;">Question:</strong>
241
+ <p style="margin: 5px 0 0 0; line-height: 1.6; color: #424242;">{question}</p>
242
+ </div>
243
+ <div>
244
+ <strong style="color: #388e3c;">Answer:</strong>
245
+ <p style="margin: 5px 0 0 0; line-height: 1.6; color: #424242;">{answer[:500] + ("..." if len(answer) > 500 else "")}</p>
246
+ </div>
247
+ """
248
+ else:
249
+ # Fallback to combined text
250
+ text = hit.text[:500] + ("..." if len(hit.text) > 500 else "")
251
+ content_html = f'<p style="margin: 0; line-height: 1.7; color: #34495e;">{text}</p>'
252
+
253
+ # Display relevance scores
254
+ bm25_score = hit.bm25
255
+ dense_score = hit.dense
256
+ rrf_score = hit.rrf
257
+
258
+ # Build title HTML only if title exists
259
+ title_html = f'<h4 style="margin: 0 0 15px 0; color: #2c3e50;">{title}</h4>' if title else ''
260
+
261
+ result_html += f"""
262
+ <div style="border: 2px solid #dee2e6; padding: 20px; margin: 20px 0; border-radius: 10px; background-color: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
263
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; margin: -20px -20px 20px -20px; border-radius: 8px 8px 0 0;">
264
+ <div style="display: flex; justify-content: space-between; align-items: center;">
265
+ <h4 style="margin: 0; color: white;">Document #{i}</h4>
266
+ <span style="background-color: rgba(255,255,255,0.2); padding: 4px 12px; border-radius: 12px; font-size: 0.85em; color: white;">
267
+ {source}
268
+ </span>
269
+ </div>
270
+ </div>
271
+
272
+ <div style="margin-bottom: 15px;">
273
+ {title_html}
274
+ {content_html}
275
+ </div>
276
+
277
+ <div style="padding-top: 12px; border-top: 1px solid #e9ecef;">
278
+ <div style="display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 10px;">
279
+ <div style="background-color: #e3f2fd; padding: 8px; border-radius: 5px; text-align: center;">
280
+ <div style="font-size: 0.75em; color: #1976d2; font-weight: bold;">BM25</div>
281
+ <div style="font-size: 1.1em; color: #0d47a1;">{bm25_score:.4f}</div>
282
+ </div>
283
+ <div style="background-color: #f3e5f5; padding: 8px; border-radius: 5px; text-align: center;">
284
+ <div style="font-size: 0.75em; color: #7b1fa2; font-weight: bold;">Dense</div>
285
+ <div style="font-size: 1.1em; color: #4a148c;">{dense_score:.4f}</div>
286
+ </div>
287
+ <div style="background-color: #e8f5e9; padding: 8px; border-radius: 5px; text-align: center;">
288
+ <div style="font-size: 0.75em; color: #388e3c; font-weight: bold;">RRF Fusion</div>
289
+ <div style="font-size: 1.1em; color: #1b5e20;">{rrf_score:.4f}</div>
290
+ </div>
291
+ </div>
292
+ </div>
293
+ </div>
294
+ """
295
+
296
+ return result_html
297
+
298
+ except Exception as e:
299
+ return f"""
300
+ <div style="padding: 20px; background-color: #f8d7da; border-radius: 10px; border-left: 5px solid #dc3545;">
301
+ <h3 style="margin-top: 0; color: #721c24;">Error</h3>
302
+ <p style="margin: 0; color: #721c24;">{str(e)}</p>
303
+ </div>
304
+ """
305
+
306
+
307
+ # Create Gradio interface
308
+ with gr.Blocks(title="Medical Document Retrieval (Cached)") as demo:
309
+ gr.Markdown("""
310
+ # Medical Document Retrieval System (CACHED VERSION)
311
+
312
+ **Models:**
313
+ - BM25 Index (keyword-based retrieval)
314
+ - Dense Embeddings (embeddinggemma-300m-medical)
315
+ - RRF Fusion (combines both approaches)
316
+
317
+ ### Features:
318
+ - Searches across 10,000+ medical documents
319
+ - Shows relevance scores from each model component
320
+ - Returns the most relevant medical information
321
+ """)
322
+
323
+ with gr.Row():
324
+ with gr.Column():
325
+ query_input = gr.Textbox(
326
+ label="Enter your medical query",
327
+ placeholder="Example: headache with blurred vision",
328
+ lines=2
329
+ )
330
+ num_results = gr.Slider(
331
+ minimum=1,
332
+ maximum=10,
333
+ value=5,
334
+ step=1,
335
+ label="Number of results to retrieve"
336
+ )
337
+ submit_btn = gr.Button("Retrieve Documents", variant="primary", size="lg")
338
+
339
+ output_html = gr.HTML(label="Search Results")
340
+
341
+ submit_btn.click(
342
+ fn=retrieve_documents,
343
+ inputs=[query_input, num_results],
344
+ outputs=output_html
345
+ )
346
+
347
+ gr.Examples(
348
+ examples=[
349
+ "headache with blurred vision",
350
+ "symptoms of diabetes",
351
+ "chest pain when exercising",
352
+ "treatment for high blood pressure",
353
+ "causes of chronic fatigue",
354
+ ],
355
+ inputs=query_input,
356
+ label="Try these example queries:"
357
+ )
358
+
359
+ gr.Markdown("""
360
+ ---
361
+ ### Technical Details
362
+ - **BM25**: Statistical keyword matching (TF-IDF based)
363
+ - **Dense**: Semantic search using transformer embeddings
364
+ - **RRF Fusion**: Reciprocal Rank Fusion combines both methods
365
+ - **Caching**: Indexes saved to disk in `cache/` folder for fast reloading
366
+
367
+ *Note: First launch builds and caches indexes (5-8 min). After that, startup takes only ~30 seconds!*
368
+ """)
369
+
370
+ print("\nOpening web interface...")
371
+ print(" Local access: http://127.0.0.1:7863")
372
+ print(" Public link will be generated...")
373
+ print("=" * 70)
374
+
375
+ if __name__ == "__main__":
376
+ demo.launch(server_name="127.0.0.1", server_port=7863, share=True)
classifier/__init__.py ADDED
File without changes
classifier/config.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ def load_env():
4
+ if os.path.exists("env.list"):
5
+ with open("env.list", "r") as f:
6
+ for line in f:
7
+ line = line.strip()
8
+ if line and not line.startswith("#"):
9
+ key, value = line.split("=", 1)
10
+ os.environ[key] = value
11
+
12
+ load_env()
13
+ HF_TOKEN = os.getenv("HF_TOKEN")
classifier/head.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from torch import nn
3
+ import torch
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+ class ClassifierHead(
7
+ nn.Module,
8
+ PyTorchModelHubMixin,
9
+ repo_url="https://huggingface.co/davidgray/health-query-triage",
10
+ pipeline_tag="text-classification",
11
+ library_name="PyTorch",
12
+ tags=["medical", "classification"],
13
+ ):
14
+ def __init__(self, num_classes: int, embedding_dim: int = 768): # Embedding-Gemma-300M has a 768-dimensional output
15
+ super().__init__()
16
+
17
+ self.linear_elu_stack = nn.Sequential(
18
+ nn.Linear(embedding_dim, 512),
19
+ nn.ELU(),
20
+ nn.Dropout(0.5),
21
+ nn.Linear(512, 512),
22
+ nn.ELU(),
23
+ nn.Dropout(0.5),
24
+ nn.Linear(512, num_classes),
25
+ )
26
+
27
+ self.softmax = nn.Softmax(dim=-1)
28
+
29
+ def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
30
+ """
31
+ Calculates logits from the sentence embedding.
32
+
33
+ Args:
34
+ features (Dict[str, torch.Tensor]): Output dictionary from the Sentence Transformer body,
35
+ containing 'sentence_embedding'.
36
+ Returns:
37
+ Dict[str, torch.Tensor]: Dictionary with the 'logits' key.
38
+ """
39
+ embeddings = features['sentence_embedding']
40
+ logits = self.linear_elu_stack(embeddings)
41
+ return {"logits": logits}
42
+
43
+ def predict(self, embeddings: torch.Tensor) -> torch.Tensor:
44
+ """
45
+ Classifies embeddings into integer labels in the range [0, num_classes).
46
+
47
+ Args:
48
+ embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].
49
+
50
+ Returns:
51
+ torch.Tensor: Integer labels with shape [num_inputs].
52
+ """
53
+ # Get probabilities and find the class with the highest probability
54
+ proba = self.predict_proba(embeddings)
55
+ return torch.argmax(proba, dim=-1)
56
+
57
+ def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
58
+ """
59
+ Classifies embeddings into probabilities for each class (summing to 1).
60
+
61
+ Args:
62
+ embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].
63
+
64
+ Returns:
65
+ torch.Tensor: Float probabilities with shape [num_inputs, num_classes].
66
+ """
67
+ # Apply the forward pass of the head to get logits
68
+ self.eval()
69
+ with torch.no_grad():
70
+ logits = self.linear_elu_stack(embeddings)
71
+ # Convert logits to probabilities using Softmax
72
+ probabilities = self.softmax(logits)
73
+ self.train() # Set back to training mode
74
+
75
+ return probabilities
76
+
77
+ def get_loss_fn(self) -> nn.Module:
78
+ """
79
+ Returns an initialized loss function for training.
80
+
81
+ Returns:
82
+ nn.Module: An initialized loss function (e.g., CrossEntropyLoss).
83
+ """
84
+ # CrossEntropyLoss expects logits (raw scores) as input
85
+ return nn.CrossEntropyLoss()
classifier/infer.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from classifier.head import ClassifierHead
2
+ from classifier.utils import CATEGORIES, CHECKPOINT_PATH, DEVICE, get_models, CLASSIFIER_NAME, get_latest_checkpoint
3
+
4
+ import argparse
5
+ import pprint
6
+ import torch
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ def classifier_init(checkpoint_path: str | None = None, model_id: str | None = CLASSIFIER_NAME) -> (SentenceTransformer, ClassifierHead):
10
+ if checkpoint_path:
11
+ latest_checkpoint = get_latest_checkpoint(checkpoint_path)
12
+ print(f"Loading checkpoint from {latest_checkpoint}")
13
+ embedding_model, classifier = get_models(model_id=latest_checkpoint)
14
+ else:
15
+ embedding_model, classifier = get_models(model_id=model_id)
16
+
17
+ return embedding_model, classifier
18
+
19
+ def predict_query(
20
+ text: list[str],
21
+ embedding_model: SentenceTransformer,
22
+ classifier_head: ClassifierHead,
23
+ ) -> dict:
24
+ """
25
+ Runs the full inference pipeline: Text -> Embedding -> Classification.
26
+ """
27
+ # Set models to evaluation mode
28
+ embedding_model.eval()
29
+ classifier_head.eval()
30
+
31
+ with torch.no_grad():
32
+ # Embed the text
33
+ embeddings = embedding_model.encode(
34
+ text,
35
+ convert_to_tensor=True,
36
+ device=DEVICE
37
+ ).to(DEVICE)
38
+
39
+ # Calculate probabilities and prediction
40
+ probabilities = classifier_head.predict_proba(embeddings)
41
+
42
+ # Get the predicted index and confidence
43
+ predicted_indices = torch.argmax(probabilities, dim=1).unsqueeze(1)
44
+ confidences = torch.gather(probabilities, dim=1, index=predicted_indices).squeeze().tolist()
45
+
46
+ # Get the predicted label name
47
+ predicted_labels = [CATEGORIES[i] for i in predicted_indices]
48
+
49
+ return {
50
+ 'prediction': predicted_labels,
51
+ 'confidence': confidences,
52
+ 'probabilities': probabilities.cpu().squeeze().tolist()
53
+ }
54
+
55
+ def test(local: bool = False):
56
+ embedding_model, classifier = classifier_init(checkpoint_path=CHECKPOINT_PATH if local else None)
57
+
58
+ queries = [
59
+ "Hi! I'm having a really bad rash on my hands. I'm pretty sure it's my excema flairing up. Is there anythign stronger than aquaphor I can use on it?",
60
+ "Hey is there any way I can get an appointment in the next month?",
61
+ "Hey is there any way I can get an appointment in the next month with a doctor?",
62
+ "I'm traveling to South America soon. Do I need to get any vaccines before I go?",
63
+ "I have this rash that popped up today.",
64
+ "How can I make this hosptial bill go away?",
65
+ "I'm so confused do I have to cover the full cost of this operation?",
66
+ ]
67
+
68
+ pred = predict_query(
69
+ text=queries,
70
+ embedding_model=embedding_model,
71
+ classifier_head=classifier,
72
+ )
73
+
74
+ pprint.pprint(pred, indent=4)
75
+
76
+ if __name__ == "__main__":
77
+ ap = argparse.ArgumentParser(
78
+ description="Inference on a classifier for triaging health queries"
79
+ )
80
+ ap.add_argument(
81
+ "--local", action="store_true",
82
+ help="Use local checkpoint"
83
+ )
84
+ args = ap.parse_args()
85
+
86
+ test(local=args.local)
classifier/modelcard_template.md ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
3
+ # Doc / guide: https://huggingface.co/docs/hub/model-cards
4
+ {{ card_data }}
5
+ ---
6
+
7
+ # Model Card for {{ model_id | default("Model ID", true) }}
8
+
9
+ <!-- Provide a quick summary of what the model is/does. -->
10
+
11
+ {{ model_summary | default("", true) }}
12
+
13
+ ## Model Details
14
+
15
+ ### Model Description
16
+
17
+ <!-- Provide a longer summary of what this model is. -->
18
+
19
+ {{ model_description | default("", true) }}
20
+
21
+ - **Developed by:** {{ developers | default("[More Information Needed]", true)}}
22
+ - **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}}
23
+ - **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}}
24
+ - **Model type:** {{ model_type | default("[More Information Needed]", true)}}
25
+ - **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}}
26
+ - **License:** {{ license | default("[More Information Needed]", true)}}
27
+ - **Finetuned from model [optional]:** {{ base_model | default("[More Information Needed]", true)}}
28
+
29
+ ### Model Sources [optional]
30
+
31
+ <!-- Provide the basic links for the model. -->
32
+
33
+ - **Repository:** {{ repo | default("[More Information Needed]", true)}}
34
+ - **Paper [optional]:** {{ paper | default("[More Information Needed]", true)}}
35
+ - **Demo [optional]:** {{ demo | default("[More Information Needed]", true)}}
36
+
37
+ ## Uses
38
+
39
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
40
+
41
+ ### Direct Use
42
+
43
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
44
+
45
+ {{ direct_use | default("[More Information Needed]", true)}}
46
+
47
+ ### Downstream Use [optional]
48
+
49
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
50
+
51
+ {{ downstream_use | default("[More Information Needed]", true)}}
52
+
53
+ ### Out-of-Scope Use
54
+
55
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
56
+
57
+ {{ out_of_scope_use | default("[More Information Needed]", true)}}
58
+
59
+ ## Bias, Risks, and Limitations
60
+
61
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
62
+
63
+ {{ bias_risks_limitations | default("[More Information Needed]", true)}}
64
+
65
+ ### Recommendations
66
+
67
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
68
+
69
+ {{ bias_recommendations | default("Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.", true)}}
70
+
71
+ ## How to Get Started with the Model
72
+
73
+ Use the code below to get started with the model.
74
+
75
+ {{ get_started_code | default("[More Information Needed]", true)}}
76
+
77
+ ## Training Details
78
+
79
+ ### Training Data
80
+
81
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
82
+
83
+ {{ training_data | default("[More Information Needed]", true)}}
84
+
85
+ ### Training Procedure
86
+
87
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
88
+
89
+ #### Preprocessing [optional]
90
+
91
+ {{ preprocessing | default("[More Information Needed]", true)}}
92
+
93
+
94
+ #### Training Hyperparameters
95
+
96
+ - **Training regime:** {{ training_regime | default("[More Information Needed]", true)}} <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
97
+
98
+ #### Speeds, Sizes, Times [optional]
99
+
100
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
101
+
102
+ {{ speeds_sizes_times | default("[More Information Needed]", true)}}
103
+
104
+ ## Evaluation
105
+
106
+ <!-- This section describes the evaluation protocols and provides the results. -->
107
+
108
+ ### Testing Data, Factors & Metrics
109
+
110
+ #### Testing Data
111
+
112
+ <!-- This should link to a Dataset Card if possible. -->
113
+
114
+ {{ testing_data | default("[More Information Needed]", true)}}
115
+
116
+ #### Factors
117
+
118
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
119
+
120
+ {{ testing_factors | default("[More Information Needed]", true)}}
121
+
122
+ #### Metrics
123
+
124
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
125
+
126
+ {{ testing_metrics | default("[More Information Needed]", true)}}
127
+
128
+ ### Results
129
+
130
+ {{ results | default("[More Information Needed]", true)}}
131
+
132
+ #### Summary
133
+
134
+ {{ results_summary | default("", true) }}
135
+
136
+ ## Model Examination [optional]
137
+
138
+ <!-- Relevant interpretability work for the model goes here -->
139
+
140
+ {{ model_examination | default("[More Information Needed]", true)}}
141
+
142
+ ## Environmental Impact
143
+
144
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
145
+
146
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
147
+
148
+ - **Hardware Type:** {{ hardware_type | default("[More Information Needed]", true)}}
149
+ - **Hours used:** {{ hours_used | default("[More Information Needed]", true)}}
150
+ - **Cloud Provider:** {{ cloud_provider | default("[More Information Needed]", true)}}
151
+ - **Compute Region:** {{ cloud_region | default("[More Information Needed]", true)}}
152
+ - **Carbon Emitted:** {{ co2_emitted | default("[More Information Needed]", true)}}
153
+
154
+ ## Technical Specifications [optional]
155
+
156
+ ### Model Architecture and Objective
157
+
158
+ {{ model_specs | default("[More Information Needed]", true)}}
159
+
160
+ ### Compute Infrastructure
161
+
162
+ {{ compute_infrastructure | default("[More Information Needed]", true)}}
163
+
164
+ #### Hardware
165
+
166
+ {{ hardware_requirements | default("[More Information Needed]", true)}}
167
+
168
+ #### Software
169
+
170
+ {{ software | default("[More Information Needed]", true)}}
171
+
172
+ ## Citation [optional]
173
+
174
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
175
+
176
+ **BibTeX:**
177
+
178
+ {{ citation_bibtex | default("[More Information Needed]", true)}}
179
+
180
+ **APA:**
181
+
182
+ {{ citation_apa | default("[More Information Needed]", true)}}
183
+
184
+ ## Glossary [optional]
185
+
186
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
187
+
188
+ {{ glossary | default("[More Information Needed]", true)}}
189
+
190
+ ## More Information [optional]
191
+
192
+ {{ more_information | default("[More Information Needed]", true)}}
193
+
194
+ ## Model Card Authors [optional]
195
+
196
+ {{ model_card_authors | default("[More Information Needed]", true)}}
197
+
198
+ ## Model Card Contact
199
+
200
+ {{ model_card_contact | default("[More Information Needed]", true)}}
classifier/query_router.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Query Router System
3
+
4
+ This module integrates the medical/insurance classifier with the reason
5
+ classification system to provide intelligent routing of healthcare portal queries.
6
+
7
+ The router first determines if a query is medical or insurance-related, then
8
+ routes accordingly:
9
+ - Insurance queries -> Direct to insurance department
10
+ - Medical queries -> Reason classification -> Appropriate medical department routing
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ from typing import Dict, List, Optional, Tuple
16
+ from pathlib import Path
17
+
18
+ # Add project root to path for imports
19
+ REPO_ROOT = Path(__file__).resolve().parents[1]
20
+ if str(REPO_ROOT) not in sys.path:
21
+ sys.path.insert(0, str(REPO_ROOT))
22
+
23
+ from classifier.infer import predict_query
24
+ from classifier.utils import get_models, CATEGORIES
25
+ from classifier.reason import predict_single_reason
26
+ from retriever.search import Retriever
27
+ from team.candidates import get_candidates
28
+
29
+ class HealthcareQueryRouter:
30
+ """
31
+ Intelligent routing system for healthcare portal queries.
32
+
33
+ Routes queries through a two-stage process:
34
+ 1. Medical vs Insurance classification
35
+ 2. For medical queries: Reason classification for department routing
36
+ """
37
+
38
+ def __init__(self,
39
+ medical_model_path: Optional[str] = None,
40
+ use_retrieval: bool = True):
41
+ """
42
+ Initialize the query router.
43
+
44
+ Args:
45
+ medical_model_path: Path to trained medical/insurance classifier
46
+ use_retrieval: Whether to use retrieval system for medical queries
47
+ """
48
+
49
+ # Initialize medical/insurance classifier
50
+ try:
51
+ self.embedding_model, self.classifier_head = get_models()
52
+
53
+ # Load trained model if available
54
+ if medical_model_path and os.path.exists(medical_model_path):
55
+ import torch
56
+ state_dict = torch.load(medical_model_path, weights_only=True)
57
+ self.classifier_head.load_state_dict(state_dict)
58
+ print(f"Loaded medical/insurance classifier from {medical_model_path}")
59
+ else:
60
+ print("Using untrained medical/insurance classifier")
61
+
62
+ except Exception as e:
63
+ print(f"Error initializing medical/insurance classifier: {e}")
64
+ raise
65
+
66
+ # Initialize retrieval system if requested
67
+ self.retriever = None
68
+ if use_retrieval:
69
+ try:
70
+ # Use default corpora configuration
71
+ corpora_config = {
72
+ "medical_qa": {
73
+ "path": "data/corpora/medical_qa.jsonl",
74
+ "text_fields": ["question", "answer", "title"],
75
+ },
76
+ "miriad": {
77
+ "path": "data/corpora/miriad_text.jsonl",
78
+ "text_fields": ["text", "title"],
79
+ }
80
+ }
81
+ # Only use available corpora
82
+ available_config = {k: v for k, v in corpora_config.items()
83
+ if Path(v["path"]).exists()}
84
+
85
+ if available_config:
86
+ self.retriever = Retriever(available_config)
87
+ print(f"Retrieval system initialized with {len(available_config)} corpora")
88
+ else:
89
+ print("No corpora files found. Retrieval disabled.")
90
+ except Exception as e:
91
+ print(f"Could not initialize retrieval system: {e}")
92
+
93
+ # Routing rules for insurance queries
94
+ self.insurance_routing = {
95
+ "department": "Insurance Department",
96
+ "priority": "normal",
97
+ "estimated_response": "1-2 business days",
98
+ "contact_method": "phone_or_email",
99
+ "description": "Insurance coverage, claims, and benefits inquiries"
100
+ }
101
+
102
+ # Medical department routing based on reason categories
103
+ self.medical_department_routing = {
104
+ "ROUTINE_CARE": {
105
+ "department": "Primary Care",
106
+ "priority": "normal",
107
+ "estimated_response": "1-7 days",
108
+ "contact_method": "standard_scheduling",
109
+ "description": "Routine healthcare and maintenance visits"
110
+ },
111
+ "PAIN_CONDITIONS": {
112
+ "department": "Pain Management",
113
+ "priority": "high",
114
+ "estimated_response": "same day to 3 days",
115
+ "contact_method": "phone_preferred",
116
+ "description": "Pain-related conditions and discomfort"
117
+ },
118
+ "INJURIES": {
119
+ "department": "Urgent Care",
120
+ "priority": "high",
121
+ "estimated_response": "same day",
122
+ "contact_method": "phone_immediate",
123
+ "description": "Injuries, sprains, and trauma-related conditions"
124
+ },
125
+ "SKIN_CONDITIONS": {
126
+ "department": "Dermatology",
127
+ "priority": "normal",
128
+ "estimated_response": "3-7 days",
129
+ "contact_method": "standard_scheduling",
130
+ "description": "Skin-related issues and conditions"
131
+ },
132
+ "STRUCTURAL_ISSUES": {
133
+ "department": "Orthopedics",
134
+ "priority": "normal",
135
+ "estimated_response": "1-14 days",
136
+ "contact_method": "standard_scheduling",
137
+ "description": "Structural problems and musculoskeletal conditions"
138
+ },
139
+ "PROCEDURES": {
140
+ "department": "Surgical Services",
141
+ "priority": "normal",
142
+ "estimated_response": "3-14 days",
143
+ "contact_method": "scheduling_coordinator",
144
+ "description": "Surgical consultations and procedures"
145
+ }
146
+ }
147
+
148
+ def route_query(self, query: str, include_retrieval: bool = True) -> Dict:
149
+ """
150
+ Route a healthcare query through the classification and routing system.
151
+
152
+ Args:
153
+ query: The user's query text
154
+ include_retrieval: Whether to include retrieval results for medical queries
155
+
156
+ Returns:
157
+ Dictionary with routing decision, confidence, and additional context
158
+ """
159
+
160
+ # Step 1: Medical vs Insurance classification
161
+ medical_prediction = predict_query([query], self.embedding_model, self.classifier_head)
162
+
163
+ # Extract prediction details
164
+ primary_category = medical_prediction['prediction'][0]
165
+ confidence = medical_prediction['confidence'] if isinstance(medical_prediction['confidence'], float) else medical_prediction['confidence'][0]
166
+ probabilities = medical_prediction['probabilities']
167
+
168
+ routing_result = {
169
+ "query": query,
170
+ "primary_classification": primary_category,
171
+ "confidence": confidence,
172
+ "all_probabilities": {
173
+ CATEGORIES[i]: float(probabilities[i]) if isinstance(probabilities[0], list) else float(probabilities[i])
174
+ for i in range(len(CATEGORIES))
175
+ },
176
+ "routing_decision": None,
177
+ "reason_classification": None,
178
+ "retrieval_results": None,
179
+ "recommendations": []
180
+ }
181
+
182
+ # Step 2: Route based on classification
183
+ if primary_category.lower() == "medical":
184
+ routing_result["routing_decision"], routing_result["reason_classification"] = self._route_medical_query(query, include_retrieval)
185
+ else:
186
+ routing_result["routing_decision"] = self._route_insurance_query()
187
+
188
+ # Step 3: Add contextual recommendations
189
+ routing_result["recommendations"] = self._generate_recommendations(
190
+ primary_category, confidence, routing_result.get("reason_classification")
191
+ )
192
+
193
+ return routing_result
194
+
195
+ def _route_medical_query(self, query: str, include_retrieval: bool = True) -> Tuple[Dict, Dict]:
196
+ """Route medical queries through reason classification."""
197
+
198
+ # Get reason classification
199
+ try:
200
+ reason_result = predict_single_reason(query)
201
+ reason_category = reason_result['category']
202
+ reason_confidence = reason_result['confidence']
203
+ reason_probabilities = reason_result['probabilities']
204
+ except Exception as e:
205
+ print(f"Reason classification failed: {e}")
206
+ # Fallback to general medical routing
207
+ reason_category = "ROUTINE_CARE"
208
+ reason_confidence = 0.5
209
+ reason_probabilities = {}
210
+
211
+ # Get department routing based on reason
212
+ routing = self.medical_department_routing.get(
213
+ reason_category,
214
+ self.medical_department_routing["ROUTINE_CARE"]
215
+ ).copy()
216
+
217
+ # Add reason classification details
218
+ reason_classification = {
219
+ "category": reason_category,
220
+ "confidence": reason_confidence,
221
+ "probabilities": reason_probabilities
222
+ }
223
+
224
+ # Add retrieval results if available and requested
225
+ if include_retrieval and self.retriever:
226
+ try:
227
+ retrieval_results = self.retriever.retrieve(query, k=5, for_ui=True)
228
+ routing["retrieval_results"] = retrieval_results
229
+ except Exception as e:
230
+ print(f"Retrieval failed: {e}")
231
+ routing["retrieval_results"] = []
232
+
233
+ return routing, reason_classification
234
+
235
+ def _route_insurance_query(self) -> Dict:
236
+ """Route insurance queries to insurance department."""
237
+ return self.insurance_routing.copy()
238
+
239
+ def _generate_recommendations(self, primary_category: str, confidence: float, reason_classification: Dict = None) -> List[str]:
240
+ """Generate contextual recommendations based on classification."""
241
+
242
+ recommendations = []
243
+
244
+ # Low confidence warning
245
+ if confidence < 0.7:
246
+ recommendations.append(
247
+ "Classification confidence is low. Consider manual review or "
248
+ "asking the user to clarify their request."
249
+ )
250
+
251
+ # Category-specific recommendations
252
+ if primary_category.lower() == "medical":
253
+ recommendations.extend([
254
+ "Consider asking follow-up questions about symptoms",
255
+ "Verify if this requires immediate attention",
256
+ "Check if patient has existing appointments or conditions"
257
+ ])
258
+
259
+ # Reason-specific recommendations
260
+ if reason_classification:
261
+ reason_category = reason_classification.get('category')
262
+ if reason_category == "PAIN_CONDITIONS":
263
+ recommendations.append("Assess pain level and duration for urgency determination")
264
+ elif reason_category == "INJURIES":
265
+ recommendations.append("Determine if immediate medical attention is required")
266
+ elif reason_category == "PROCEDURES":
267
+ recommendations.append("Verify insurance pre-authorization requirements")
268
+
269
+ elif primary_category.lower() == "insurance":
270
+ recommendations.extend([
271
+ "Have patient account information ready",
272
+ "Verify current insurance information and benefits",
273
+ "Prepare to explain coverage details and requirements"
274
+ ])
275
+
276
+ return recommendations
277
+
278
+ def batch_route_queries(self, queries: List[str]) -> List[Dict]:
279
+ """Route multiple queries efficiently."""
280
+ return [self.route_query(query) for query in queries]
281
+
282
+ def get_routing_statistics(self, queries: List[str]) -> Dict:
283
+ """Analyze routing patterns for a batch of queries."""
284
+
285
+ results = self.batch_route_queries(queries)
286
+
287
+ # Count categories
288
+ primary_counts = {}
289
+ reason_counts = {}
290
+ confidence_scores = []
291
+
292
+ for result in results:
293
+ # Primary classification counts
294
+ primary_category = result["primary_classification"]
295
+ primary_counts[primary_category] = primary_counts.get(primary_category, 0) + 1
296
+ confidence_scores.append(result["confidence"])
297
+
298
+ # Reason classification counts (for medical queries)
299
+ if result["reason_classification"]:
300
+ reason_category = result["reason_classification"]["category"]
301
+ reason_counts[reason_category] = reason_counts.get(reason_category, 0) + 1
302
+
303
+ return {
304
+ "total_queries": len(queries),
305
+ "primary_distribution": primary_counts,
306
+ "reason_distribution": reason_counts,
307
+ "average_confidence": sum(confidence_scores) / len(confidence_scores),
308
+ "low_confidence_queries": len([c for c in confidence_scores if c < 0.7]),
309
+ "primary_percentages": {
310
+ cat: (count / len(queries)) * 100
311
+ for cat, count in primary_counts.items()
312
+ },
313
+ "reason_percentages": {
314
+ cat: (count / len(queries)) * 100
315
+ for cat, count in reason_counts.items()
316
+ }
317
+ }
318
+
319
+
320
+ def demo_router():
321
+ """Demonstrate the query router functionality."""
322
+
323
+ print("Initializing Healthcare Query Router...")
324
+ router = HealthcareQueryRouter()
325
+
326
+ # Test queries covering different categories
327
+ test_queries = [
328
+ # Insurance queries
329
+ "My insurance claim was denied, can you help?",
330
+ "What does my insurance cover for this procedure?",
331
+ "I need to verify my insurance benefits",
332
+
333
+ # Medical queries - different reasons
334
+ "I have heel pain when I walk", # PAIN_CONDITIONS
335
+ "I need routine foot care", # ROUTINE_CARE
336
+ "I sprained my ankle playing sports", # INJURIES
337
+ "My toenail is ingrown and infected", # SKIN_CONDITIONS
338
+ "I have flat feet and need evaluation", # STRUCTURAL_ISSUES
339
+ "I need a cortisone injection", # PROCEDURES
340
+ ]
341
+
342
+ print(f"\nRouting {len(test_queries)} test queries...\n")
343
+
344
+ for i, query in enumerate(test_queries, 1):
345
+ print(f"Query {i}: {query}")
346
+
347
+ result = router.route_query(query)
348
+
349
+ print(f" Primary: {result['primary_classification']} "
350
+ f"(confidence: {result['confidence']:.3f})")
351
+
352
+ if result['reason_classification']:
353
+ print(f" Reason: {result['reason_classification']['category']} "
354
+ f"(confidence: {result['reason_classification']['confidence']:.3f})")
355
+
356
+ print(f" Department: {result['routing_decision']['department']}")
357
+ print(f" Priority: {result['routing_decision']['priority']}")
358
+ print(f" Response Time: {result['routing_decision']['estimated_response']}")
359
+
360
+ if result['recommendations']:
361
+ print(f" Recommendation: {result['recommendations'][0]}")
362
+
363
+ print()
364
+
365
+ # Show routing statistics
366
+ print("Routing Statistics:")
367
+ stats = router.get_routing_statistics(test_queries)
368
+
369
+ print("Primary Classification:")
370
+ for category, percentage in stats['primary_percentages'].items():
371
+ print(f" {category}: {percentage:.1f}%")
372
+
373
+ if stats['reason_percentages']:
374
+ print("Reason Classification:")
375
+ for category, percentage in stats['reason_percentages'].items():
376
+ print(f" {category}: {percentage:.1f}%")
377
+
378
+ print(f"Average Confidence: {stats['average_confidence']:.3f}")
379
+ print(f"Low Confidence Queries: {stats['low_confidence_queries']}")
380
+
381
+
382
+ if __name__ == "__main__":
383
+ demo_router()
classifier/reason/README.md ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Healthcare Reason Classification System
2
+
3
+ This module implements a specialized classifier for healthcare visit reasons using real clinic data to classify patient queries into specific healthcare reason categories.
4
+
5
+ ## Overview
6
+
7
+ The reason classifier addresses the challenge of routing medical healthcare queries to appropriate specialized departments. It classifies medical queries into specific reason categories based on actual healthcare visit data.
8
+
9
+ ## Architecture
10
+
11
+ ### Classification Categories
12
+
13
+ | Category | Description | Examples |
14
+ |----------|-------------|----------|
15
+ | `ROUTINE_CARE` | Routine healthcare, maintenance visits, general care | "I need routine foot care", "Regular nail care appointment" |
16
+ | `PAIN_CONDITIONS` | Various pain-related conditions and discomfort | "I have heel pain when I walk", "My ankle is sore" |
17
+ | `INJURIES` | Sprains, wounds, trauma-related conditions | "I sprained my ankle playing sports", "I have a wound that won't heal" |
18
+ | `SKIN_CONDITIONS` | Skin-related issues and conditions | "My toenail is ingrown and infected", "I have calluses on my feet" |
19
+ | `STRUCTURAL_ISSUES` | Structural problems and related conditions | "I have flat feet", "I need evaluation for plantar fasciitis" |
20
+ | `PROCEDURES` | Injections, surgical consultations, post-operative care | "I need a cortisone injection", "Post-surgical follow-up" |
21
+
22
+ ### Technical Implementation
23
+
24
+ - **Base Model**: `sentence-transformers/embeddinggemma-300m-medical`
25
+ - **Architecture**: SetFit with frozen embeddings + trainable classification head
26
+ - **Training**: Real healthcare data from clinic appointment records
27
+ - **Integration**: Works as part of the complete healthcare routing system
28
+
29
+ ## Quick Start
30
+
31
+ ### 1. Train the Classifier
32
+
33
+ ```bash
34
+ # Train with real healthcare data
35
+ python classifier/reason/train_reason.py
36
+
37
+ # The training script will:
38
+ # - Load real healthcare data from data/reason_for_visit_data.xlsx
39
+ # - Map reasons to categories using keyword matching
40
+ # - Train the classifier with frozen embeddings
41
+ # - Save the trained model to classifier/reason_checkpoints/
42
+ ```
43
+
44
+ ### 2. Use the CLI
45
+
46
+ ```bash
47
+ # Classify a single reason query
48
+ python cli/reason_classifier_cli_new.py "I have heel pain when I walk"
49
+
50
+ # Interactive mode
51
+ python cli/reason_classifier_cli_new.py --interactive
52
+
53
+ # Batch processing
54
+ python cli/reason_classifier_cli_new.py --batch queries.txt --output results.json
55
+
56
+ # Use complete healthcare routing system
57
+ python cli/healthcare_classifier_cli.py "I need routine foot care"
58
+ ```
59
+
60
+ ### 3. Programmatic Usage
61
+
62
+ ```python
63
+ from classifier.reason import ReasonClassifier, predict_single_reason
64
+
65
+ # Using the main classifier class
66
+ classifier = ReasonClassifier()
67
+ predictions = classifier.predict(["I have heel pain when I walk"])
68
+ print(predictions[0]['category']) # Output: PAIN_CONDITIONS
69
+
70
+ # Using convenience function
71
+ result = predict_single_reason("I need routine foot care")
72
+ print(result['category']) # Output: ROUTINE_CARE
73
+ print(result['confidence']) # Confidence score
74
+ print(result['probabilities']) # All category probabilities
75
+ ```
76
+
77
+ ## System Integration
78
+
79
+ ### Complete Healthcare Routing Workflow
80
+
81
+ ```
82
+ User Query
83
+
84
+ Medical vs Insurance Classification
85
+
86
+ ┌─────────────────┬─────────────────┐
87
+ │ Insurance │ Medical │
88
+ │ Queries │ Queries │
89
+ │ ↓ │ ↓ │
90
+ │ Insurance │ Reason │
91
+ │ Department │ Classification │
92
+ │ │ ↓ │
93
+ │ │ • ROUTINE_CARE │
94
+ │ │ • PAIN_CONDITIONS │
95
+ │ │ • INJURIES │
96
+ │ │ • SKIN_CONDITIONS │
97
+ │ │ • STRUCTURAL_ISSUES │
98
+ │ │ • PROCEDURES │
99
+ └─────────────────┴─────────────────┘
100
+ ```
101
+
102
+ ### Integration with Healthcare System
103
+
104
+ The reason classifier integrates as part of the complete healthcare routing system:
105
+
106
+ 1. **Primary Classification**: Medical vs Insurance queries
107
+ 2. **Reason Classification**: Medical queries → Specific reason categories
108
+ 3. **Department Routing**: Route to appropriate specialized departments
109
+
110
+ ## Training Data Strategy
111
+
112
+ ### Real Healthcare Data
113
+
114
+ The system uses actual healthcare clinic data:
115
+
116
+ ```python
117
+ # Data source: data/reason_for_visit_data.xlsx
118
+ # Contains real patient visit reasons and appointment types
119
+ # Examples from actual data:
120
+ # - "Heel pain"
121
+ # - "Routine foot care"
122
+ # - "Ingrown toenail"
123
+ # - "Ankle sprain"
124
+ # - "Plantar fasciitis"
125
+ ```
126
+
127
+ ### Category Mapping Strategy
128
+
129
+ The system uses keyword-based mapping to categorize real healthcare reasons:
130
+
131
+ ```python
132
+ def map_reason_to_category(reason: str) -> int:
133
+ reason_lower = reason.lower()
134
+
135
+ # ROUTINE_CARE (routine care, maintenance visits)
136
+ if any(word in reason_lower for word in ['routine', 'nail care', 'calluses']):
137
+ return 0
138
+
139
+ # PAIN_CONDITIONS (various pain-related conditions)
140
+ elif any(word in reason_lower for word in ['pain', 'ache', 'sore']):
141
+ return 1
142
+
143
+ # ... other categories
144
+ ```
145
+
146
+ ## Performance Metrics
147
+
148
+ ### Expected Performance
149
+ - **Accuracy**: Based on real healthcare data patterns
150
+ - **Categories**: 6 specialized healthcare reason categories
151
+ - **Confidence**: Variable based on training data quality
152
+
153
+ ### Evaluation Framework
154
+
155
+ ```bash
156
+ # Train and evaluate the model
157
+ python classifier/reason/train_reason.py
158
+
159
+ # Test the trained model
160
+ python classifier/reason/infer_reason.py
161
+
162
+ # Results include:
163
+ # - Training metrics
164
+ # - Category distribution
165
+ # - Example predictions with confidence scores
166
+ ```
167
+
168
+ ## File Structure
169
+
170
+ ```
171
+ classifier/reason/
172
+ ├── __init__.py # Package initialization and exports
173
+ ├── README.md # This documentation
174
+ ├── reason_classifier.py # Main ReasonClassifier class
175
+ ├── infer_reason.py # Inference functions and utilities
176
+ └── train_reason.py # Training script and functions
177
+ ```
178
+
179
+ ## API Reference
180
+
181
+ ### ReasonClassifier
182
+
183
+ ```python
184
+ class ReasonClassifier:
185
+ def __init__(self, data_file: str = "data/reason_for_visit_data.xlsx")
186
+ def predict(self, queries: List[str]) -> List[Dict]
187
+ def train(self, train_data: pd.DataFrame = None, eval_data: Optional[pd.DataFrame] = None)
188
+ def save_model(self, path: str)
189
+ def load_model(self, path: str)
190
+ def create_real_dataset(self) -> pd.DataFrame
191
+ def analyze_real_data(self)
192
+ ```
193
+
194
+ ### Inference Functions
195
+
196
+ ```python
197
+ def predict_single_reason(query: str) -> dict
198
+ def predict_reason_query(text: list[str], embedding_model, classifier_head) -> dict
199
+ def get_reason_models() -> tuple
200
+ def test_reason_classifier()
201
+ ```
202
+
203
+ ### Training Functions
204
+
205
+ ```python
206
+ def get_reason_model(num_classes: int)
207
+ def get_reason_dataset() -> pd.DataFrame
208
+ def map_reason_to_category(reason: str) -> int
209
+ def preprocess_reason_data(df: pd.DataFrame) -> pd.DataFrame
210
+ ```
211
+
212
+ ## Data Requirements
213
+
214
+ ### Healthcare Data Format
215
+
216
+ The system expects healthcare data in Excel format with these columns:
217
+
218
+ ```
219
+ Required columns:
220
+ - "Reason For Visit": The primary reason for the healthcare visit
221
+ - "Appointment Type": Type of appointment (optional, used for context)
222
+
223
+ Example data:
224
+ | Reason For Visit | Appointment Type |
225
+ |------------------|------------------|
226
+ | Heel pain | Follow-up |
227
+ | Routine foot care| Maintenance |
228
+ | Ingrown toenail | New Patient |
229
+ ```
230
+
231
+ ## Deployment Considerations
232
+
233
+ ### Production Readiness
234
+
235
+ 1. **Model Persistence**: Trained models saved with timestamps in `classifier/reason_checkpoints/`
236
+ 2. **Error Handling**: Graceful fallbacks for prediction failures
237
+ 3. **Real Data Integration**: Uses actual healthcare clinic data
238
+ 4. **Device Support**: CPU/GPU/MPS compatibility
239
+
240
+ ### Scalability
241
+
242
+ - **Batch Processing**: Efficient handling of multiple queries
243
+ - **Integration**: Works with existing healthcare routing system
244
+ - **Checkpoints**: Automatic model saving with timestamps
245
+
246
+ ## Future Enhancements
247
+
248
+ ### Data Improvements
249
+
250
+ 1. **Expanded Dataset**: Include more healthcare specialties
251
+ 2. **Active Learning**: Improve model with real-world feedback
252
+ 3. **Multi-language Support**: Support for non-English healthcare queries
253
+
254
+ ### Advanced Features
255
+
256
+ 1. **Confidence Calibration**: Improve confidence score reliability
257
+ 2. **Hierarchical Classification**: Sub-categories within reason types
258
+ 3. **Context Awareness**: Consider patient history and appointment context
259
+
260
+ ## Troubleshooting
261
+
262
+ ### Common Issues
263
+
264
+ 1. **Data Loading Errors**: Ensure `data/reason_for_visit_data.xlsx` exists
265
+ 2. **Low Confidence**: May indicate need for more training data or model retraining
266
+ 3. **Import Errors**: Ensure all dependencies are installed and paths are correct
267
+
268
+ ### Debug Mode
269
+
270
+ ```python
271
+ # Test the classifier with sample queries
272
+ from classifier.reason.infer_reason import test_reason_classifier
273
+ test_reason_classifier()
274
+
275
+ # Check model predictions with probabilities
276
+ from classifier.reason import predict_single_reason
277
+ result = predict_single_reason("ambiguous query")
278
+ print(result['probabilities'])
279
+ ```
280
+
281
+ ### Model Training Issues
282
+
283
+ ```bash
284
+ # Check if healthcare data is available
285
+ ls -la data/reason_for_visit_data.xlsx
286
+
287
+ # Verify model training
288
+ python classifier/reason/train_reason.py
289
+
290
+ # Test inference after training
291
+ python classifier/reason/infer_reason.py
292
+ ```
293
+
294
+ ## Contributing
295
+
296
+ ### Adding New Categories
297
+
298
+ 1. Update `REASON_CATEGORIES` in `reason_classifier.py`, `infer_reason.py`, and `train_reason.py`
299
+ 2. Update category mapping logic in `map_reason_to_category()`
300
+ 3. Retrain the model with new categories
301
+ 4. Update documentation and examples
302
+
303
+ ### Improving Training Data
304
+
305
+ 1. Add more real healthcare examples to the dataset
306
+ 2. Improve keyword mapping for better categorization
307
+ 3. Implement more sophisticated NLP techniques for category assignment
308
+
309
+ ## License
310
+
311
+ This module is part of the health-query-classifier project and follows the same licensing terms.
classifier/reason/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reason classification module for healthcare queries.
3
+
4
+ This module contains components for classifying healthcare visit reasons
5
+ into predefined categories based on real healthcare data.
6
+ """
7
+
8
+ from .reason_classifier import ReasonClassifier, REASON_CATEGORIES
9
+ from .infer_reason import predict_reason_query, predict_single_reason, get_reason_models
10
+
11
+ __all__ = [
12
+ 'ReasonClassifier',
13
+ 'REASON_CATEGORIES',
14
+ 'predict_reason_query',
15
+ 'predict_single_reason',
16
+ 'get_reason_models'
17
+ ]
classifier/reason/infer_reason.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference module for Healthcare Reason Classification
3
+
4
+ This module provides inference for the reason classification system,
5
+ separate from the medical/insurance classifier.
6
+ """
7
+
8
+ from ..head import ClassifierHead
9
+ from datetime import datetime
10
+ import os
11
+ import pprint
12
+ import torch
13
+ from sentence_transformers import SentenceTransformer
14
+
15
+ # Reason-specific configuration
16
+ REASON_CATEGORIES = {
17
+ 0: "ROUTINE_CARE",
18
+ 1: "PAIN_CONDITIONS",
19
+ 2: "INJURIES",
20
+ 3: "SKIN_CONDITIONS",
21
+ 4: "STRUCTURAL_ISSUES",
22
+ 5: "PROCEDURES"
23
+ }
24
+
25
+ REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints"
26
+ DATETIME_FORMAT = "%Y%m%d_%H%M%S"
27
+ MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
28
+
29
+ def get_device():
30
+ """Get the best available device for inference."""
31
+ if torch.backends.mps.is_available():
32
+ return torch.device("mps")
33
+ elif torch.cuda.is_available():
34
+ return torch.device("cuda")
35
+ else:
36
+ return torch.device("cpu")
37
+
38
+ DEVICE = get_device()
39
+
40
+ def get_reason_models():
41
+ """Get the embedding model and classifier head for reason inference."""
42
+ # Load embedding model
43
+ embedding_model = SentenceTransformer(
44
+ MODEL_NAME,
45
+ prompts={
46
+ 'classification': 'task: healthcare reason classification | query: ',
47
+ 'retrieval (query)': 'task: search result | query: ',
48
+ 'retrieval (document)': 'title: {title | "none"} | text: ',
49
+ },
50
+ default_prompt_name='classification',
51
+ )
52
+
53
+ # Load classifier head (for 6 reason categories)
54
+ classifier_head = ClassifierHead(len(REASON_CATEGORIES))
55
+
56
+ return embedding_model.to(DEVICE), classifier_head.to(DEVICE)
57
+
58
+ def predict_reason_query(
59
+ text: list[str],
60
+ embedding_model: SentenceTransformer,
61
+ classifier_head: ClassifierHead,
62
+ ) -> dict:
63
+ """
64
+ Runs the full inference pipeline for reason classification: Text -> Embedding -> Classification.
65
+ """
66
+ # Set models to evaluation mode
67
+ embedding_model.eval()
68
+ classifier_head.eval()
69
+
70
+ with torch.no_grad():
71
+ # Embed the text
72
+ embeddings = embedding_model.encode(
73
+ text,
74
+ convert_to_tensor=True,
75
+ device=DEVICE
76
+ ).to(DEVICE)
77
+
78
+ # Calculate probabilities and prediction
79
+ probabilities = classifier_head.predict_proba(embeddings)
80
+
81
+ # Get the predicted index and confidence
82
+ predicted_indices = torch.argmax(probabilities, dim=1)
83
+
84
+ # Convert tensors to Python types safely
85
+ if predicted_indices.dim() == 0: # Single prediction
86
+ predicted_indices = [predicted_indices.item()]
87
+ else:
88
+ predicted_indices = predicted_indices.cpu().tolist()
89
+
90
+ # Get confidences
91
+ confidences = []
92
+ for i, idx in enumerate(predicted_indices):
93
+ conf = probabilities[i][idx].item() if probabilities.dim() > 1 else probabilities[idx].item()
94
+ confidences.append(conf)
95
+
96
+ # Get the predicted label names
97
+ predicted_labels = [REASON_CATEGORIES[i] for i in predicted_indices]
98
+
99
+ return {
100
+ 'prediction': predicted_labels,
101
+ 'confidence': confidences,
102
+ 'probabilities': probabilities.cpu().tolist()
103
+ }
104
+
105
+ def predict_single_reason(query: str) -> dict:
106
+ """Convenience function to predict a single reason query."""
107
+ try:
108
+ embedding_model, classifier_head = get_reason_models()
109
+
110
+ # Try to load the most recent trained checkpoint
111
+ if os.path.exists(REASON_CHECKPOINT_PATH):
112
+ for d in os.listdir(REASON_CHECKPOINT_PATH):
113
+ if d.endswith('.pt'):
114
+ checkpoint_path = f"{REASON_CHECKPOINT_PATH}/{d}"
115
+ try:
116
+ state_dict = torch.load(checkpoint_path, weights_only=True, map_location=DEVICE)
117
+ classifier_head.load_state_dict(state_dict)
118
+ print(f"Loaded trained weights from {checkpoint_path}")
119
+ break
120
+ except Exception as e:
121
+ print(f"Could not load weights from {checkpoint_path}: {e}")
122
+
123
+ result = predict_reason_query([query], embedding_model, classifier_head)
124
+
125
+ # Extract values safely
126
+ prediction = result['prediction'][0] if isinstance(result['prediction'], list) else str(result['prediction'])
127
+ confidence = result['confidence'] if isinstance(result['confidence'], float) else (result['confidence'][0] if isinstance(result['confidence'], list) else float(result['confidence']))
128
+
129
+ # Handle probabilities - ensure it's a list
130
+ probabilities = result['probabilities']
131
+ if isinstance(probabilities, list) and len(probabilities) > 0:
132
+ if isinstance(probabilities[0], list):
133
+ probabilities = probabilities[0]
134
+
135
+ # Create probability dictionary
136
+ prob_dict = {}
137
+ for i, category in REASON_CATEGORIES.items():
138
+ if i < len(probabilities):
139
+ prob_dict[category] = float(probabilities[i])
140
+ else:
141
+ prob_dict[category] = 0.0
142
+
143
+ return {
144
+ 'query': query,
145
+ 'category': prediction,
146
+ 'confidence': confidence,
147
+ 'probabilities': prob_dict
148
+ }
149
+ except Exception as e:
150
+ # Return a default classification if the model fails
151
+ return {
152
+ 'query': query,
153
+ 'category': 'GENERAL_MEDICAL',
154
+ 'confidence': 0.5,
155
+ 'probabilities': {category: 1.0/len(REASON_CATEGORIES) for category in REASON_CATEGORIES.values()},
156
+ 'error': str(e)
157
+ }
158
+
159
+ def test_reason_classifier():
160
+ """Test the reason classifier with sample queries."""
161
+ latest = None
162
+ path = ""
163
+
164
+ # Try to load the most recent checkpoint
165
+ if os.path.exists(REASON_CHECKPOINT_PATH):
166
+ for d in os.listdir(REASON_CHECKPOINT_PATH):
167
+ if d.endswith('.pt'):
168
+ checkpoint_path = f"{REASON_CHECKPOINT_PATH}/{d}"
169
+ print(f"Found checkpoint: {checkpoint_path}")
170
+ path = checkpoint_path
171
+ break
172
+
173
+ if not path:
174
+ print("No trained checkpoints found. Using untrained model.")
175
+ else:
176
+ print("No checkpoint directory found. Using untrained model.")
177
+
178
+ embedding_model, classifier = get_reason_models()
179
+
180
+ # Load trained weights if available
181
+ if path and os.path.exists(path):
182
+ try:
183
+ state_dict = torch.load(path, weights_only=True, map_location=DEVICE)
184
+ classifier.load_state_dict(state_dict)
185
+ print(f"Loaded trained weights from {path}")
186
+ except Exception as e:
187
+ print(f"Could not load weights: {e}. Using untrained model.")
188
+
189
+ # Test queries for reason classification
190
+ queries = [
191
+ "I have heel pain when I walk",
192
+ "My toenail is ingrown and painful",
193
+ "I need routine foot care",
194
+ "I sprained my ankle playing sports",
195
+ "I have plantar fasciitis",
196
+ "I need a cortisone injection"
197
+ ]
198
+
199
+ print("\nTesting reason classification:")
200
+ pred = predict_reason_query(
201
+ text=queries,
202
+ embedding_model=embedding_model,
203
+ classifier_head=classifier,
204
+ )
205
+
206
+ pprint.pprint(pred, indent=4)
207
+
208
+ if __name__ == "__main__":
209
+ test_reason_classifier()
classifier/reason/reason_classifier.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Healthcare Reason for Visit Classifier
3
+
4
+ This module implements a classifier for healthcare clinic queries
5
+ using real healthcare data from clinic appointment records.
6
+
7
+ Categories based on the actual data:
8
+ - ROUTINE_CARE: Routine care, maintenance visits
9
+ - PAIN_CONDITIONS: Various pain-related conditions
10
+ - INJURIES: Sprains, wounds, trauma-related visits
11
+ - SKIN_CONDITIONS: Skin-related conditions and issues
12
+ - STRUCTURAL_ISSUES: Structural problems and conditions
13
+ - PROCEDURES: Injections, surgical consults, postop care
14
+ """
15
+
16
+ import os
17
+ import torch
18
+ import pandas as pd
19
+ import numpy as np
20
+ from typing import List, Dict, Tuple, Optional
21
+ from sentence_transformers import SentenceTransformer
22
+ from setfit import SetFitModel
23
+ from sklearn.model_selection import train_test_split
24
+ from sklearn.metrics import classification_report, confusion_matrix
25
+ from datasets import Dataset
26
+ import json
27
+
28
+ from ..head import ClassifierHead
29
+
30
+ # Healthcare reason categories based on real data analysis
31
+ REASON_CATEGORIES = {
32
+ 0: "ROUTINE_CARE",
33
+ 1: "PAIN_CONDITIONS",
34
+ 2: "INJURIES",
35
+ 3: "SKIN_CONDITIONS",
36
+ 4: "STRUCTURAL_ISSUES",
37
+ 5: "PROCEDURES"
38
+ }
39
+
40
+ CATEGORY_DESCRIPTIONS = {
41
+ "ROUTINE_CARE": "Routine healthcare, maintenance visits, general care",
42
+ "PAIN_CONDITIONS": "Various pain-related conditions and discomfort",
43
+ "INJURIES": "Sprains, wounds, trauma-related conditions",
44
+ "SKIN_CONDITIONS": "Skin-related issues and conditions",
45
+ "STRUCTURAL_ISSUES": "Structural problems and related conditions",
46
+ "PROCEDURES": "Injections, surgical consultations, post-operative care"
47
+ }
48
+
49
+ class ReasonClassifier:
50
+ """
51
+ Healthcare Reason Classifier that uses real clinic data to classify
52
+ patient queries into specific healthcare reason categories.
53
+ """
54
+
55
+ def __init__(self, data_file: str = "data/reason_for_visit_data.xlsx"):
56
+ self.model_name = "sentence-transformers/embeddinggemma-300m-medical"
57
+ self.num_classes = len(REASON_CATEGORIES)
58
+ self.categories = REASON_CATEGORIES
59
+ self.data_file = data_file
60
+ self.model = None
61
+ self.device = self._get_device()
62
+
63
+ # Load and process real data
64
+ self.healthcare_df = self._load_data()
65
+ self._initialize_model()
66
+
67
+ def _get_device(self):
68
+ """Get the best available device for training/inference."""
69
+ if torch.backends.mps.is_available():
70
+ return torch.device("mps")
71
+ elif torch.cuda.is_available():
72
+ return torch.device("cuda")
73
+ else:
74
+ return torch.device("cpu")
75
+
76
+ def _load_data(self) -> pd.DataFrame:
77
+ """Load the real healthcare dataset."""
78
+ try:
79
+ df = pd.read_excel(self.data_file)
80
+ print(f"Loaded {len(df)} healthcare records from {self.data_file}")
81
+ print(f"Unique reasons: {df['Reason For Visit'].nunique()}")
82
+ return df
83
+ except Exception as e:
84
+ print(f"Error loading data: {e}")
85
+ raise RuntimeError(f"Failed to load healthcare data from {self.data_file}")
86
+
87
+ def _initialize_model(self):
88
+ """Initialize the model with the existing infrastructure."""
89
+ try:
90
+ model_body = SentenceTransformer(
91
+ self.model_name,
92
+ prompts={
93
+ 'classification': 'task: healthcare reason classification | query: ',
94
+ 'retrieval (query)': 'task: search result | query: ',
95
+ 'retrieval (document)': 'title: {title | "none"} | text: ',
96
+ },
97
+ default_prompt_name='classification',
98
+ )
99
+
100
+ model_head = ClassifierHead(self.num_classes, embedding_dim=768)
101
+ self.model = SetFitModel(model_body, model_head)
102
+ self.model.freeze("body") # Freeze embedding weights
103
+ self.model = self.model.to(self.device)
104
+
105
+ print(f"Initialized ReasonClassifier on {self.device}")
106
+
107
+ except Exception as e:
108
+ print(f"Error initializing model: {e}")
109
+ raise RuntimeError("Failed to initialize reason classifier")
110
+
111
+ def _map_reason_to_category(self, reason: str) -> int:
112
+ """
113
+ Map real healthcare reasons to categories using keyword matching.
114
+ Based on the actual data distribution.
115
+ """
116
+ reason_lower = reason.lower()
117
+
118
+ # ROUTINE_CARE (routine foot care, nail care, calluses)
119
+ if any(word in reason_lower for word in ['routine', 'nail care', 'calluses']):
120
+ return 0
121
+
122
+ # PAIN_CONDITIONS (heel pain, ankle pain, foot pain, etc.)
123
+ if any(word in reason_lower for word in ['pain', 'ache', 'sore']):
124
+ return 1
125
+
126
+ # INJURIES (ankle sprain, wounds, trauma)
127
+ if any(word in reason_lower for word in ['sprain', 'wound', 'injury', 'trauma']):
128
+ return 2
129
+
130
+ # SKIN_CONDITIONS (ingrown toenail, calluses, skin issues)
131
+ if any(word in reason_lower for word in ['ingrown', 'toenail', 'callus', 'skin']):
132
+ return 3
133
+
134
+ # STRUCTURAL_ISSUES (flat feet, plantar fasciitis, achilles)
135
+ if any(word in reason_lower for word in ['flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon']):
136
+ return 4
137
+
138
+ # PROCEDURES (injection, surgical consult, postop)
139
+ if any(word in reason_lower for word in ['injection', 'surgical', 'consult', 'postop', 'procedure']):
140
+ return 5
141
+
142
+ # Default to pain conditions (most common category)
143
+ return 1
144
+
145
+ def create_real_dataset(self) -> pd.DataFrame:
146
+ """
147
+ Create training dataset from real healthcare data.
148
+ """
149
+
150
+ training_data = []
151
+
152
+ for _, row in self.healthcare_df.iterrows():
153
+ reason = row['Reason For Visit']
154
+ appointment_type = row['Appointment Type']
155
+
156
+ # Map reason to category
157
+ category_id = self._map_reason_to_category(reason)
158
+
159
+ # Create enhanced text with context
160
+ enhanced_text = reason
161
+ if pd.notna(appointment_type):
162
+ enhanced_text += f" | {appointment_type}"
163
+
164
+ training_data.append({
165
+ 'text': enhanced_text,
166
+ 'label': category_id,
167
+ 'category': self.categories[category_id],
168
+ 'original_reason': reason,
169
+ 'appointment_type': appointment_type
170
+ })
171
+
172
+ df = pd.DataFrame(training_data)
173
+
174
+ # Show category distribution
175
+ print("\nCategory distribution in training data:")
176
+ for cat_id, cat_name in self.categories.items():
177
+ count = len(df[df['label'] == cat_id])
178
+ percentage = (count / len(df)) * 100
179
+ print(f" {cat_name}: {count} samples ({percentage:.1f}%)")
180
+
181
+ return df.sample(frac=1).reset_index(drop=True) # Shuffle
182
+
183
+ def train(self, train_data: pd.DataFrame = None, eval_data: Optional[pd.DataFrame] = None,
184
+ epochs: int = 16, output_dir: str = "classifier/reason_checkpoints"):
185
+ """Train the healthcare reason classifier."""
186
+
187
+ if train_data is None:
188
+ train_data = self.create_real_dataset()
189
+
190
+ if eval_data is None:
191
+ train_data, eval_data = train_test_split(train_data, test_size=0.2,
192
+ stratify=train_data['label'],
193
+ random_state=42)
194
+
195
+ train_dataset = Dataset.from_pandas(train_data)
196
+ eval_dataset = Dataset.from_pandas(eval_data)
197
+
198
+ from setfit import Trainer, TrainingArguments
199
+
200
+ args = TrainingArguments(
201
+ output_dir=output_dir,
202
+ num_epochs=(0, epochs), # Skip contrastive learning, only train head
203
+ eval_strategy='epoch',
204
+ eval_steps=100,
205
+ save_strategy='epoch',
206
+ logging_steps=50,
207
+ )
208
+
209
+ trainer = Trainer(
210
+ model=self.model,
211
+ train_dataset=train_dataset,
212
+ eval_dataset=eval_dataset,
213
+ metric='accuracy',
214
+ column_mapping={"text": "text", "label": "label"},
215
+ args=args,
216
+ )
217
+
218
+ print("Starting training...")
219
+ trainer.train()
220
+
221
+ # Evaluate
222
+ metrics = trainer.evaluate(eval_dataset)
223
+ print(f"Training completed. Final metrics: {metrics}")
224
+
225
+ return metrics
226
+
227
+ def predict(self, queries: List[str]) -> List[Dict]:
228
+ """
229
+ Predict healthcare reason categories for a list of queries.
230
+
231
+ Returns:
232
+ List of dictionaries with 'query', 'category', 'confidence', 'probabilities'
233
+ """
234
+ if not self.model:
235
+ raise RuntimeError("Model not initialized. Train or load a model first.")
236
+
237
+ predictions = []
238
+
239
+ for query in queries:
240
+ # Get prediction using SetFit's built-in methods
241
+ pred_label = self.model.predict([query])[0]
242
+ pred_proba = self.model.predict_proba([query])[0]
243
+
244
+ category = self.categories[int(pred_label)]
245
+ confidence = float(pred_proba[int(pred_label)])
246
+
247
+ predictions.append({
248
+ 'query': query,
249
+ 'category': category,
250
+ 'confidence': confidence,
251
+ 'probabilities': {self.categories[i]: float(prob)
252
+ for i, prob in enumerate(pred_proba)}
253
+ })
254
+
255
+ return predictions
256
+
257
+ def save_model(self, path: str):
258
+ """Save the trained model."""
259
+ os.makedirs(os.path.dirname(path), exist_ok=True)
260
+ self.model.save_pretrained(path)
261
+
262
+ # Save category mapping
263
+ with open(os.path.join(path, 'categories.json'), 'w') as f:
264
+ json.dump(self.categories, f)
265
+
266
+ print(f"Model saved to {path}")
267
+
268
+ def load_model(self, path: str):
269
+ """Load a trained model."""
270
+ self.model = SetFitModel.from_pretrained(path)
271
+ self.model = self.model.to(self.device)
272
+
273
+ # Load category mapping
274
+ with open(os.path.join(path, 'categories.json'), 'r') as f:
275
+ self.categories = {int(k): v for k, v in json.load(f).items()}
276
+
277
+ print(f"Model loaded from {path}")
278
+
279
+ def evaluate_on_test_set(self, test_data: pd.DataFrame) -> Dict:
280
+ """Evaluate the model on a test dataset."""
281
+ predictions = self.predict(test_data['text'].tolist())
282
+
283
+ y_true = test_data['label'].tolist()
284
+ y_pred = [list(self.categories.keys())[list(self.categories.values()).index(p['category'])]
285
+ for p in predictions]
286
+
287
+ # Classification report
288
+ report = classification_report(y_true, y_pred,
289
+ target_names=list(self.categories.values()),
290
+ output_dict=True)
291
+
292
+ # Confusion matrix
293
+ cm = confusion_matrix(y_true, y_pred)
294
+
295
+ return {
296
+ 'classification_report': report,
297
+ 'confusion_matrix': cm.tolist(),
298
+ 'accuracy': report['accuracy']
299
+ }
300
+
301
+ def analyze_real_data(self):
302
+ """Analyze the real healthcare data to understand patterns."""
303
+ print("Real Data Analysis:")
304
+ print("=" * 50)
305
+
306
+ print(f"Total records: {len(self.healthcare_df)}")
307
+ print(f"Unique reasons: {self.healthcare_df['Reason For Visit'].nunique()}")
308
+
309
+ print("\nTop 15 reasons for visit:")
310
+ top_reasons = self.healthcare_df['Reason For Visit'].value_counts().head(15)
311
+ for reason, count in top_reasons.items():
312
+ category_id = self._map_reason_to_category(reason)
313
+ category_name = self.categories[category_id]
314
+ print(f" {reason}: {count} ({category_name})")
315
+
316
+ print(f"\nAppointment types:")
317
+ print(self.healthcare_df['Appointment Type'].value_counts())
318
+
319
+
320
+ def main():
321
+ """Example usage and training script for healthcare reason data."""
322
+ print("Initializing Healthcare Reason Classifier...")
323
+
324
+ # Initialize classifier with real data
325
+ classifier = ReasonClassifier()
326
+
327
+ # Analyze the real data
328
+ classifier.analyze_real_data()
329
+
330
+ # Create training dataset from real data
331
+ print("\nCreating training dataset from real healthcare data...")
332
+ dataset = classifier.create_real_dataset()
333
+
334
+ print(f"Dataset created with {len(dataset)} real examples")
335
+
336
+ # Train the model
337
+ print("\nTraining classifier...")
338
+ metrics = classifier.train(dataset, epochs=20)
339
+
340
+ # Save the model
341
+ model_path = "classifier/reason_model"
342
+ classifier.save_model(model_path)
343
+
344
+ # Test predictions on healthcare reason queries
345
+ test_queries = [
346
+ "I have heel pain when I walk",
347
+ "My toenail is ingrown and painful",
348
+ "I need routine foot care",
349
+ "I sprained my ankle playing sports",
350
+ "I have flat feet and need evaluation",
351
+ "I need a cortisone injection for my foot",
352
+ "I have plantar fasciitis",
353
+ "My foot wound is not healing"
354
+ ]
355
+
356
+ print("\nTesting predictions on healthcare reason queries:")
357
+ predictions = classifier.predict(test_queries)
358
+
359
+ for pred in predictions:
360
+ print(f"Query: {pred['query']}")
361
+ print(f"Category: {pred['category']} (confidence: {pred['confidence']:.3f})")
362
+ print("---")
363
+
364
+
365
+ if __name__ == "__main__":
366
+ main()
classifier/reason/train_reason.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for Healthcare Reason Classification
3
+
4
+ This script trains a classifier for healthcare visit reasons using real
5
+ healthcare data. It creates a separate system from the medical/insurance
6
+ classifier.
7
+ """
8
+
9
+ from sentence_transformers import SentenceTransformer
10
+ from setfit import SetFitModel, Trainer, TrainingArguments
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ # Add project root to path for imports
15
+ REPO_ROOT = Path(__file__).resolve().parents[2]
16
+ if str(REPO_ROOT) not in sys.path:
17
+ sys.path.insert(0, str(REPO_ROOT))
18
+
19
+ from classifier.head import ClassifierHead
20
+ import os
21
+ import pandas as pd
22
+ from sklearn.model_selection import train_test_split
23
+ from datasets import Dataset
24
+ import torch
25
+ from pathlib import Path
26
+ from datetime import datetime
27
+
28
+ # Reason-specific configuration
29
+ REASON_CATEGORIES = {
30
+ 0: "ROUTINE_CARE",
31
+ 1: "PAIN_CONDITIONS",
32
+ 2: "INJURIES",
33
+ 3: "SKIN_CONDITIONS",
34
+ 4: "STRUCTURAL_ISSUES",
35
+ 5: "PROCEDURES"
36
+ }
37
+
38
+ REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints"
39
+ HEALTHCARE_DATA_PATH = "data/reason_for_visit_data.xlsx"
40
+ MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
41
+
42
+ def get_device():
43
+ """Get the best available device for training/inference."""
44
+ if torch.backends.mps.is_available():
45
+ return torch.device("mps")
46
+ elif torch.cuda.is_available():
47
+ return torch.device("cuda")
48
+ else:
49
+ return torch.device("cpu")
50
+
51
+ def get_reason_model(num_classes: int):
52
+ """Get model for reason classification."""
53
+ try:
54
+ model_body = SentenceTransformer(
55
+ MODEL_NAME,
56
+ prompts={
57
+ 'classification': 'task: healthcare reason classification | query: ',
58
+ 'retrieval (query)': 'task: search result | query: ',
59
+ 'retrieval (document)': 'title: {title | "none"} | text: ',
60
+ },
61
+ default_prompt_name='classification',
62
+ )
63
+ # Freeze weights of embedding model
64
+ model_head = ClassifierHead(num_classes)
65
+ model = SetFitModel(model_body, model_head)
66
+ model.freeze("body")
67
+
68
+ except Exception as e:
69
+ print(f"Error loading model {MODEL_NAME}: {e}")
70
+ raise RuntimeError("Failed to load the embedding model.")
71
+
72
+ device = get_device()
73
+ print(f"Using device: {device}")
74
+ return model.to(device)
75
+
76
+ def get_reason_dataset() -> pd.DataFrame:
77
+ """Load the healthcare reason dataset from Excel file."""
78
+ try:
79
+ if not os.path.exists(HEALTHCARE_DATA_PATH):
80
+ raise FileNotFoundError(f"Healthcare data file not found: {HEALTHCARE_DATA_PATH}")
81
+
82
+ print(f"Loading healthcare data from {HEALTHCARE_DATA_PATH}...")
83
+ df = pd.read_excel(HEALTHCARE_DATA_PATH)
84
+ print(f"Loaded {len(df)} healthcare records")
85
+ return df
86
+
87
+ except Exception as e:
88
+ print(f"Error loading healthcare dataset: {e}")
89
+ raise Exception(f"Failed to load healthcare data: {e}")
90
+
91
+ def map_reason_to_category(reason: str) -> int:
92
+ """Map healthcare reasons to categories using keyword matching."""
93
+ reason_lower = reason.lower()
94
+
95
+ # ROUTINE_CARE (routine care, maintenance visits)
96
+ if any(word in reason_lower for word in ['routine', 'nail care', 'calluses', 'maintenance']):
97
+ return 0
98
+
99
+ # PAIN_CONDITIONS (various pain-related conditions)
100
+ elif any(word in reason_lower for word in ['pain', 'ache', 'sore', 'hurt']):
101
+ return 1
102
+
103
+ # INJURIES (sprains, wounds, trauma)
104
+ elif any(word in reason_lower for word in ['sprain', 'wound', 'injury', 'trauma', 'cut', 'bruise']):
105
+ return 2
106
+
107
+ # SKIN_CONDITIONS (skin-related issues)
108
+ elif any(word in reason_lower for word in ['ingrown', 'toenail', 'callus', 'corn', 'skin']):
109
+ return 3
110
+
111
+ # STRUCTURAL_ISSUES (structural problems)
112
+ elif any(word in reason_lower for word in ['flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon', 'arch']):
113
+ return 4
114
+
115
+ # PROCEDURES (injections, surgical consultations)
116
+ elif any(word in reason_lower for word in ['injection', 'surgical', 'consult', 'postop', 'surgery', 'procedure']):
117
+ return 5
118
+
119
+ # Default to pain conditions (most common category)
120
+ else:
121
+ return 1
122
+
123
+ def preprocess_reason_data(df: pd.DataFrame) -> pd.DataFrame:
124
+ """Preprocess the healthcare reason dataset for training."""
125
+ training_data = []
126
+
127
+ for _, row in df.iterrows():
128
+ reason = row['Reason For Visit']
129
+ appointment_type = row.get('Appointment Type', '')
130
+
131
+ # Map reason to category using keyword matching
132
+ category_id = map_reason_to_category(reason)
133
+
134
+ # Create enhanced text with context
135
+ enhanced_text = reason
136
+ if pd.notna(appointment_type) and appointment_type:
137
+ enhanced_text += f" | {appointment_type}"
138
+
139
+ training_data.append({
140
+ 'text': enhanced_text,
141
+ 'label': category_id,
142
+ 'category': REASON_CATEGORIES[category_id],
143
+ 'original_reason': reason
144
+ })
145
+
146
+ processed_df = pd.DataFrame(training_data)
147
+
148
+ # Show category distribution
149
+ print("\nReason category distribution in training data:")
150
+ for cat_id, cat_name in REASON_CATEGORIES.items():
151
+ count = len(processed_df[processed_df['label'] == cat_id])
152
+ percentage = (count / len(processed_df)) * 100
153
+ print(f" {cat_name}: {count} samples ({percentage:.1f}%)")
154
+
155
+ return processed_df
156
+
157
+ def main():
158
+ print("Healthcare Reason Classification - Training Pipeline")
159
+ print("=" * 60)
160
+
161
+ # Load and preprocess data
162
+ df = get_reason_dataset()
163
+ df = preprocess_reason_data(df)
164
+
165
+ # Get model
166
+ model = get_reason_model(len(REASON_CATEGORIES))
167
+
168
+ # Split data
169
+ train, test = train_test_split(
170
+ df, test_size=0.2, stratify=df['label'], random_state=42
171
+ )
172
+
173
+ print(f"\nData split:")
174
+ print(f" Training: {len(train)} samples")
175
+ print(f" Testing: {len(test)} samples")
176
+
177
+ train_dataset = Dataset.from_pandas(train)
178
+ test_dataset = Dataset.from_pandas(test)
179
+
180
+ # Ensure checkpoint directory exists
181
+ Path(REASON_CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True)
182
+
183
+ # Training arguments
184
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
185
+ output_dir = f"{REASON_CHECKPOINT_PATH}/training_{timestamp}"
186
+
187
+ args = TrainingArguments(
188
+ output_dir=output_dir,
189
+ # Skip contrastive fine-tuning (body is frozen)
190
+ num_epochs=(0, 20),
191
+ eval_strategy='epoch',
192
+ eval_steps=100,
193
+ save_strategy='epoch',
194
+ logging_steps=50,
195
+ load_best_model_at_end=True,
196
+ metric_for_best_model='accuracy',
197
+ )
198
+
199
+ trainer = Trainer(
200
+ model=model,
201
+ train_dataset=train_dataset,
202
+ eval_dataset=test_dataset,
203
+ metric='accuracy',
204
+ column_mapping={"text": "text", "label": "label"},
205
+ args=args,
206
+ )
207
+
208
+ print("\nStarting reason classification training...")
209
+ trainer.train()
210
+
211
+ # Evaluate
212
+ print("\nEvaluating reason classification model...")
213
+ metrics = trainer.evaluate(test_dataset)
214
+ print(f"Final evaluation metrics: {metrics}")
215
+
216
+ # Save the trained classifier head
217
+ model_save_path = f"{REASON_CHECKPOINT_PATH}/reason_classifier_head_{timestamp}.pt"
218
+ torch.save(model.model_head.state_dict(), model_save_path)
219
+ print(f"Reason classifier head saved to: {model_save_path}")
220
+
221
+ return metrics
222
+
223
+ if __name__ == "__main__":
224
+ main()
classifier/train.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from classifier.utils import CHECKPOINT_PATH, DATETIME_FORMAT, get_models, CATEGORIES, DEVICE, CLASSIFIER_NAME
2
+ from classifier.config import HF_TOKEN
3
+ from huggingface_hub import HfApi
4
+ from jinja2 import Template
5
+
6
+ import argparse
7
+ from datetime import datetime
8
+ import datasets as ds
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import os
12
+ import pandas as pd
13
+ import torch
14
+ from torch.utils.data import DataLoader
15
+
16
+ def even_split(prefix: str, target: int, splits: int, total: int) -> str:
17
+ result = ""
18
+ target_amount_per_split = int(target / splits)
19
+ total_amount_per_split = int(total / splits)
20
+
21
+ for i in range(splits):
22
+ left = total_amount_per_split*i
23
+ right = left + target_amount_per_split
24
+ result += f"{prefix}[{int(left)}:{int(right)}]"
25
+
26
+ if i != splits - 1:
27
+ result += "+"
28
+
29
+ return result
30
+
31
+ def get_model_train_test():
32
+ # Login using e.g. `huggingface-cli login` to access this dataset
33
+
34
+ def add_static_label(row, column_name, label):
35
+ row[column_name] = label
36
+ return row
37
+
38
+ # Miriad
39
+ train_split = even_split("train", 50000, 100, 4470000)
40
+ miriad = ds.load_dataset("tomaarsen/miriad-4.4M-split", split={"train":train_split, "test": "test", "validation": "eval"})
41
+ miriad = miriad.rename_column("question", "text")
42
+ miriad = miriad.remove_columns("passage_text")
43
+ miriad = miriad.map(add_static_label, fn_kwargs={"column_name": "label", "label": "medical"})
44
+ # print(miriad)
45
+
46
+ # Insurance
47
+ train_split = even_split("train", 5000, 20, 21300)
48
+ insurance = ds.load_dataset("deccan-ai/insuranceQA-v2", split={"train":train_split, "test":"test", "validation":"validation"})
49
+ insurance = insurance.rename_column("input", "text")
50
+ insurance = insurance.remove_columns(["output"])
51
+ insurance = insurance.map(add_static_label, fn_kwargs={"column_name": "label", "label": "insurance"})
52
+ # print(insurance)
53
+
54
+ # Interleave datasets (mix the datasets into one randomly)
55
+ train = ds.interleave_datasets([miriad["train"], insurance["train"]], stopping_strategy="all_exhausted")
56
+ _ , unique_indices = np.unique(train["text"], return_index=True, axis=0)
57
+ train = train.select(unique_indices.tolist())
58
+ test = ds.interleave_datasets([miriad["test"], insurance["test"]], stopping_strategy="all_exhausted")
59
+ _ , unique_indices = np.unique(test["text"], return_index=True, axis=0)
60
+ test = test.select(unique_indices.tolist())
61
+ validation = ds.interleave_datasets([miriad["validation"], insurance["validation"]], stopping_strategy="all_exhausted")
62
+ _ , unique_indices = np.unique(validation["text"], return_index=True, axis=0)
63
+ validation = validation.select(unique_indices.tolist())
64
+
65
+ print(f"train: {len(train)}, validation: {len(validation)}, test: {len(test)}")
66
+
67
+ # Get models
68
+ embedding_model, classifier = get_models()
69
+
70
+ return embedding_model, classifier, train, test, validation, CATEGORIES
71
+
72
+ def test_loop(dataloader, model, loss_fn):
73
+ # Set the model to evaluation mode - important for batch normalization and dropout layers
74
+ # Unnecessary in this situation but added for best practices
75
+ model.eval()
76
+ size = len(dataloader.dataset)
77
+ num_batches = len(dataloader)
78
+ test_loss, correct = 0, 0
79
+
80
+ # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
81
+ # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
82
+ with torch.no_grad():
83
+ for batch in dataloader:
84
+ pred = model(batch)['logits']
85
+ test_loss += loss_fn(pred, batch['label']).item()
86
+ correct += (pred.argmax(1) == batch['label']).type(torch.float).sum().item()
87
+
88
+ avg_loss = test_loss / num_batches
89
+ accuracy = correct / size
90
+
91
+ return avg_loss, accuracy
92
+
93
+ def train_loop(dataloader, model, loss_fn, optimizer, batch_size = 64, epochs = 10):
94
+ size = len(dataloader.dataset)
95
+ total_loss = 0
96
+ batch_losses = []
97
+
98
+ # Set models to training mode
99
+ model.train()
100
+
101
+ for iteration, batch in enumerate(dataloader):
102
+ # --- 1. Zero Gradients ---
103
+ # Only zero gradients for the parameters you want to update (the classifier head)
104
+ optimizer.zero_grad()
105
+
106
+ # --- 3. Forward Pass: Embeddings -> Logits ---
107
+ # The classifier head takes the embeddings from the body
108
+ pred = model(batch)['logits']
109
+
110
+ # --- 4. Calculate Loss ---
111
+ loss = loss_fn(pred, batch['label'])
112
+
113
+ # --- 5. Backward Pass & Update ---
114
+ loss.backward()
115
+ optimizer.step()
116
+
117
+ cur_loss = loss.item()
118
+ batch_losses.append(cur_loss)
119
+ total_loss += cur_loss
120
+
121
+ if iteration % 100 == 0:
122
+ current = iteration * batch_size + len(batch['label'])
123
+ print(f"loss: {cur_loss:>7f} [{current:>5d}/{size:>5d}]")
124
+
125
+ return total_loss, batch_losses
126
+
127
+ def generate_model_card(save_dir: str, accuracy: float, loss: float, epoch: int):
128
+ with open("classifier/modelcard_template.md", "r") as f:
129
+ template_content = f.read()
130
+
131
+ template = Template(template_content)
132
+
133
+ card_content = template.render(
134
+ model_id=CLASSIFIER_NAME,
135
+ model_summary="A simple medical query triage classifier.",
136
+ model_description="This model classifies queries into 'medical' or 'insurance' categories. It uses EmbeddingGemma-300M as a backbone.",
137
+ developers="David Gray",
138
+ model_type="Text Classification",
139
+ language="en",
140
+ license="mit",
141
+ base_model="sentence-transformers/embeddinggemma-300m-medical",
142
+ repo=f"https://huggingface.co/{CLASSIFIER_NAME}",
143
+ results_summary=f"Epoch: {epoch+1}\nValidation Accuracy: {accuracy*100:.2f}%\nValidation Loss: {loss:.4f}",
144
+ training_data="Miriad (medical) and InsuranceQA (insurance) datasets.",
145
+ testing_metrics="Accuracy, Loss",
146
+ results=f"Accuracy: {accuracy:.4f}, Loss: {loss:.4f}"
147
+ )
148
+
149
+ with open(f"{save_dir}/README.md", "w") as f:
150
+ f.write(card_content)
151
+
152
+ def push_model_card(save_dir: str, repo_id: str, token: str = None):
153
+ api = HfApi(token=token)
154
+ api.upload_file(
155
+ path_or_fileobj=f"{save_dir}/README.md",
156
+ path_in_repo="README.md",
157
+ repo_id=repo_id,
158
+ repo_type="model"
159
+ )
160
+
161
+ def label_to_int(embedding_model, label_names: list):
162
+ """Creates a dictionary mapping label strings to their integer IDs."""
163
+ label_map = {name: i for i, name in enumerate(label_names)}
164
+
165
+ def collate_fn(batch):
166
+ # 1. Extract texts and labels from the batch (list of dictionaries)
167
+ texts = [item['text'] for item in batch]
168
+ labels = [item['label'] for item in batch]
169
+
170
+ # 2. Tokenize the texts using the embedding model's tokenizer
171
+ # The tokenizer is attached to the embedding_model
172
+ with torch.no_grad():
173
+ tokenized_text = embedding_model.encode(
174
+ texts,
175
+ convert_to_tensor=True,
176
+ device=DEVICE
177
+ ).clone().detach()
178
+
179
+ # 3. Convert string labels to integers
180
+ int_labels = [label_map[l] for l in labels]
181
+ tokenized_labels = torch.tensor(int_labels, dtype=torch.long)
182
+
183
+ # 4. Add the labels as a PyTorch tensor
184
+ tokenized_batch = {'sentence_embedding': tokenized_text.to(DEVICE), 'label': tokenized_labels.to(DEVICE)}
185
+
186
+ return tokenized_batch
187
+
188
+ return collate_fn
189
+
190
+ def train(push_to_hub: bool = False):
191
+ start_datetime = datetime.now()
192
+
193
+ save_dir = f'{CHECKPOINT_PATH}/{start_datetime.strftime(DATETIME_FORMAT)}'
194
+ os.makedirs(save_dir, exist_ok=True)
195
+
196
+ embedding_model, model, train_ds, test_ds, validation_ds, labels = get_model_train_test()
197
+ batch_size = 64
198
+ custom_collate_fn = label_to_int(embedding_model, labels)
199
+
200
+ train_dataloader = DataLoader(
201
+ train_ds,
202
+ batch_size=batch_size,
203
+ shuffle=True,
204
+ collate_fn=custom_collate_fn
205
+ )
206
+ test_dataloader = DataLoader(
207
+ test_ds,
208
+ batch_size=batch_size,
209
+ shuffle=True,
210
+ collate_fn=custom_collate_fn
211
+ )
212
+ validation_dataloader = DataLoader(
213
+ validation_ds,
214
+ batch_size=batch_size,
215
+ shuffle=True,
216
+ collate_fn=custom_collate_fn
217
+ )
218
+
219
+ loss_fn = model.get_loss_fn()
220
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
221
+ save_per_epoch = 1
222
+ epochs = 1
223
+ patience = 1
224
+ min_val_loss = float('inf')
225
+ patience_counter = 0
226
+ history = {
227
+ 'train_loss_epoch': [],
228
+ 'train_loss_batch': [],
229
+ 'validation_accuracy': [],
230
+ 'validation_loss_epoch': [],
231
+ 'test_accuracy': [],
232
+ 'test_loss': []
233
+ }
234
+
235
+ for epoch in range(epochs):
236
+ print(f"Epoch {epoch+1}:\n-------------------------------")
237
+
238
+ # Train
239
+ total_loss, batch_losses = train_loop(train_dataloader, model, loss_fn, optimizer)
240
+ avg_epoch_loss = total_loss / len(train_dataloader)
241
+ history['train_loss_epoch'].append(avg_epoch_loss)
242
+ history['train_loss_batch'].extend(batch_losses)
243
+
244
+ summary = f"Epoch {epoch+1}:"
245
+
246
+ # Validate
247
+ val_loss_avg, val_accuracy = test_loop(validation_dataloader, model, loss_fn)
248
+ history['validation_accuracy'].append(val_accuracy)
249
+ history['validation_loss_epoch'].append(val_loss_avg)
250
+
251
+ summary += f" - loss: {avg_epoch_loss}\n"
252
+ summary += f" - training loss: {avg_epoch_loss}\n"
253
+ summary += f" - validation loss: {val_loss_avg:>8f}\n"
254
+ summary += f" - validation accuracy: {(100*val_accuracy):>0.1f}%\n"
255
+
256
+ # Save checkpoint
257
+ if epoch % save_per_epoch == 0:
258
+ # Save model
259
+ model.save_pretrained(save_dir)
260
+
261
+ # Generate and push model card
262
+ # generate_model_card(save_dir, val_accuracy, val_loss_avg, epoch)
263
+ # push_model_card(save_dir, CLASSIFIER_NAME, token=HF_TOKEN)
264
+
265
+ summary += f" -- {save_dir}\n"
266
+
267
+ history_df = pd.DataFrame.from_dict(history, orient='index').transpose()
268
+ history_df.to_csv(f"{save_dir}/history.csv", index=False)
269
+
270
+ # Push model to Hugging Face
271
+ if push_to_hub:
272
+ model.push_to_hub(CLASSIFIER_NAME, token=HF_TOKEN)
273
+ else:
274
+ summary += "\n"
275
+
276
+ print(summary)
277
+
278
+ if val_loss_avg < min_val_loss:
279
+ min_val_loss = val_loss_avg
280
+ patience_counter = 0
281
+ else:
282
+ patience_counter += 1
283
+ if patience_counter >= patience:
284
+ print("Early stopping triggered due to no improvement in validation loss.")
285
+ break
286
+
287
+ # Evaluate on test dataset
288
+ test_loss_avg, test_accuracy = test_loop(test_dataloader, model, loss_fn)
289
+ history['test_accuracy'].append(test_accuracy)
290
+ history['test_loss'].append(test_loss_avg)
291
+ print(f"Test: Accuracy: {(100*test_accuracy):>0.1f}%, Avg loss: {test_loss_avg:>8f}")
292
+
293
+ # Save the final model
294
+ model.save_pretrained(save_dir)
295
+
296
+ # generate_model_card(save_dir, test_accuracy, test_loss_avg, epochs-1)
297
+ # push_model_card(save_dir, CLASSIFIER_NAME, token=HF_TOKEN)
298
+
299
+ # Save loss history
300
+ history_df = pd.DataFrame.from_dict(history, orient='index').transpose()
301
+ history_df.to_csv(f"{save_dir}/history.csv", index=False)
302
+
303
+ # Plot training loss per batch
304
+ fig, ax = plt.subplots()
305
+ ax.plot(history['train_loss_batch'])
306
+ ax.set_title('Training Loss per Batch')
307
+ ax.set_xlabel('Batch')
308
+ ax.set_ylabel('Loss')
309
+ fig.savefig(f"{save_dir}/loss.png")
310
+
311
+ if push_to_hub:
312
+ model.push_to_hub(CLASSIFIER_NAME, token=HF_TOKEN)
313
+
314
+ if __name__ == "__main__":
315
+ ap = argparse.ArgumentParser(
316
+ description="Train a classifier for triaging health queries"
317
+ )
318
+ ap.add_argument(
319
+ "--push", action="store_true",
320
+ help="Push model to Hugging Face"
321
+ )
322
+ args = ap.parse_args()
323
+
324
+ train(push_to_hub=args.push)
classifier/utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for Healthcare Classification System
3
+
4
+ This module contains shared constants and utilities for the healthcare
5
+ classification system.
6
+ """
7
+
8
+ from classifier.head import ClassifierHead
9
+
10
+ from classifier.config import load_env
11
+
12
+ import os
13
+ from sentence_transformers import SentenceTransformer
14
+ import torch
15
+ from datetime import datetime
16
+ from pathlib import Path
17
+
18
+ # Load environment variables (including HF_TOKEN)
19
+ load_env()
20
+
21
+ MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
22
+ CLASSIFIER_NAME = "davidgray/health-query-triage"
23
+ CATEGORIES: list[str] = ["medical", "insurance"]
24
+
25
+ # Model and training configuration
26
+ MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
27
+ CHECKPOINT_PATH = "classifier/checkpoints"
28
+ DATETIME_FORMAT = "%Y%m%d_%H%M%S"
29
+
30
+ # Device configuration - use David's newer approach with fallback
31
+ try:
32
+ DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
33
+ except AttributeError:
34
+ # Fallback for older PyTorch versions
35
+ if torch.backends.mps.is_available():
36
+ DEVICE = torch.device("mps")
37
+ elif torch.cuda.is_available():
38
+ DEVICE = torch.device("cuda")
39
+ else:
40
+ DEVICE = torch.device("cpu")
41
+
42
+ print(f"Using {DEVICE} device")
43
+
44
+ def get_models(model_id: str | None = None, num_labels: int = len(CATEGORIES)) -> tuple[SentenceTransformer, ClassifierHead]:
45
+ """
46
+ Loads embeddinggemma-300m-medical model and initializes the classification head.
47
+
48
+ Returns:
49
+ tuple: (embedding_model, classifier_head)
50
+ """
51
+ try:
52
+ model_body = SentenceTransformer(
53
+ MODEL_NAME,
54
+ prompts={
55
+ 'classification': 'task: classification | query: ',
56
+ 'retrieval (query)': 'task: search result | query: ',
57
+ 'retrieval (document)': 'title: {title | "none"} | text: ',
58
+ },
59
+ default_prompt_name='classification',
60
+ )
61
+
62
+ if model_id:
63
+ model_head = ClassifierHead.from_pretrained(model_id)
64
+ else:
65
+ model_head = ClassifierHead(num_labels)
66
+
67
+ except Exception as e:
68
+ print(f"Error loading model {MODEL_NAME}: {e}")
69
+ print("Please ensure you have an internet connection and the transformers library installed.")
70
+ raise RuntimeError("Failed to load the embedding model.")
71
+
72
+ return model_body.to(DEVICE), model_head.to(DEVICE)
73
+
74
+ def get_latest_checkpoint(checkpoint_path: str):
75
+ return os.path.join(checkpoint_path, sorted(os.listdir(checkpoint_path))[-1])
cli/healthcare_classifier_cli.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ End-to-End Healthcare Classification CLI
3
+
4
+ This provides a complete classification pipeline:
5
+ 1. First classifies as "medical" or "insurance"
6
+ 2. If medical, applies reason classification for detailed categorization
7
+
8
+ IMPORTANT: Activate virtual environment first!
9
+ Usage:
10
+ source .venv/bin/activate
11
+ python cli/healthcare_classifier_cli.py --interactive
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import sys
17
+ from pathlib import Path
18
+
19
+ # Add project root to path
20
+ REPO_ROOT = Path(__file__).resolve().parents[1]
21
+ if str(REPO_ROOT) not in sys.path:
22
+ sys.path.insert(0, str(REPO_ROOT))
23
+
24
+ def classify_healthcare_query(query: str):
25
+ """
26
+ Complete healthcare query classification pipeline.
27
+
28
+ Step 1: Medical vs Insurance classification
29
+ Step 2: If medical, apply reason classification
30
+ """
31
+
32
+ print(f"Query: {query}")
33
+ print("=" * 60)
34
+
35
+ try:
36
+ # Add classifier to path
37
+ sys.path.append('classifier')
38
+
39
+ # Step 1: Medical vs Insurance Classification
40
+ print("🔍 Step 1: Medical vs Insurance Classification")
41
+ print("-" * 40)
42
+
43
+ from infer import predict_query
44
+ from utils import get_models
45
+
46
+ # Load medical/insurance classifier
47
+ embedding_model, classifier_head = get_models()
48
+
49
+ # Get medical vs insurance prediction
50
+ result = predict_query([query], embedding_model, classifier_head)
51
+
52
+ primary_category = result['prediction'][0]
53
+ confidence = result['confidence']
54
+ if isinstance(confidence, list):
55
+ confidence = confidence[0]
56
+
57
+ print(f"Primary Classification: {primary_category.upper()}")
58
+ print(f"Confidence: {confidence:.4f}")
59
+
60
+ # Show probabilities
61
+ probabilities = result['probabilities']
62
+ if isinstance(probabilities[0], list):
63
+ probabilities = probabilities[0]
64
+
65
+ print("Probabilities:")
66
+ from utils import CATEGORIES
67
+ for i, category in enumerate(CATEGORIES):
68
+ print(f" {category}: {probabilities[i]:.4f}")
69
+
70
+ # Step 2: If medical, apply reason classification
71
+ if primary_category.lower() == 'medical':
72
+ print(f"\n🏥 Step 2: Medical Reason Classification")
73
+ print("-" * 40)
74
+
75
+ try:
76
+ from classifier.reason.infer_reason import predict_single_reason
77
+
78
+ reason_result = predict_single_reason(query)
79
+
80
+ print(f"Medical Reason: {reason_result['category']}")
81
+ print(f"Reason Confidence: {reason_result['confidence']:.4f}")
82
+
83
+ print("Reason Probabilities:")
84
+ sorted_probs = sorted(reason_result['probabilities'].items(),
85
+ key=lambda x: x[1], reverse=True)
86
+ for category, prob in sorted_probs:
87
+ print(f" {category}: {prob:.4f}")
88
+
89
+ # Final routing decision
90
+ print(f"\n🎯 Final Routing Decision")
91
+ print("-" * 25)
92
+ print(f"Route to: {reason_result['category']} Department")
93
+ print(f"Overall confidence: Medical ({confidence:.3f}) → {reason_result['category']} ({reason_result['confidence']:.3f})")
94
+
95
+ return {
96
+ 'primary_classification': primary_category,
97
+ 'primary_confidence': confidence,
98
+ 'reason_classification': reason_result['category'],
99
+ 'reason_confidence': reason_result['confidence'],
100
+ 'routing': f"{reason_result['category']} Department"
101
+ }
102
+
103
+ except Exception as e:
104
+ print(f"⚠️ Reason classification failed: {e}")
105
+ print("Note: Make sure reason classifier is trained")
106
+ print(f"Routing to: General Medical Department")
107
+
108
+ return {
109
+ 'primary_classification': primary_category,
110
+ 'primary_confidence': confidence,
111
+ 'reason_classification': 'GENERAL_MEDICAL',
112
+ 'reason_confidence': 0.0,
113
+ 'routing': 'General Medical Department'
114
+ }
115
+
116
+ else:
117
+ # Insurance query
118
+ print(f"\n💳 Final Routing Decision")
119
+ print("-" * 25)
120
+ print(f"Route to: Insurance Department")
121
+ print(f"Confidence: {confidence:.3f}")
122
+
123
+ return {
124
+ 'primary_classification': primary_category,
125
+ 'primary_confidence': confidence,
126
+ 'reason_classification': None,
127
+ 'reason_confidence': None,
128
+ 'routing': 'Insurance Department'
129
+ }
130
+
131
+ except Exception as e:
132
+ print(f"❌ Classification failed: {e}")
133
+ if "No module named 'torch'" in str(e):
134
+ print("\n🔧 SOLUTION:")
135
+ print("You need to activate the virtual environment first!")
136
+ print("Run these commands:")
137
+ print(" source .venv/bin/activate")
138
+ print(" python cli/healthcare_classifier_cli.py --interactive")
139
+ else:
140
+ print("Note: Make sure models are trained and available")
141
+ return None
142
+
143
+ def classify_batch_queries(queries_file: str, output_file: str = None):
144
+ """Process multiple queries through the complete pipeline."""
145
+
146
+ try:
147
+ # Read queries
148
+ with open(queries_file, 'r') as f:
149
+ if queries_file.endswith('.json'):
150
+ data = json.load(f)
151
+ if isinstance(data, list):
152
+ queries = data
153
+ else:
154
+ queries = data.get('queries', [])
155
+ else:
156
+ queries = [line.strip() for line in f if line.strip()]
157
+
158
+ print(f"Processing {len(queries)} queries through complete pipeline...")
159
+ print("=" * 60)
160
+
161
+ results = []
162
+ for i, query in enumerate(queries, 1):
163
+ print(f"\n📋 Query {i}/{len(queries)}")
164
+ result = classify_healthcare_query(query)
165
+ if result:
166
+ result['query'] = query
167
+ results.append(result)
168
+ print()
169
+
170
+ # Save results if output file specified
171
+ if output_file:
172
+ output_data = {
173
+ 'queries': queries,
174
+ 'predictions': results,
175
+ 'summary': {
176
+ 'total_queries': len(queries),
177
+ 'medical_queries': len([r for r in results if r['primary_classification'].lower() == 'medical']),
178
+ 'insurance_queries': len([r for r in results if r['primary_classification'].lower() == 'insurance']),
179
+ 'reason_categories': {}
180
+ }
181
+ }
182
+
183
+ # Count reason categories
184
+ for result in results:
185
+ if result['reason_classification']:
186
+ cat = result['reason_classification']
187
+ output_data['summary']['reason_categories'][cat] = output_data['summary']['reason_categories'].get(cat, 0) + 1
188
+
189
+ with open(output_file, 'w') as f:
190
+ json.dump(output_data, f, indent=2)
191
+
192
+ print(f"📄 Results saved to {output_file}")
193
+
194
+ # Show summary
195
+ medical_count = len([r for r in results if r['primary_classification'].lower() == 'medical'])
196
+ insurance_count = len([r for r in results if r['primary_classification'].lower() == 'insurance'])
197
+
198
+ print(f"\n📊 Summary:")
199
+ print(f" Medical queries: {medical_count} ({medical_count/len(results)*100:.1f}%)")
200
+ print(f" Insurance queries: {insurance_count} ({insurance_count/len(results)*100:.1f}%)")
201
+
202
+ if medical_count > 0:
203
+ reason_counts = {}
204
+ for result in results:
205
+ if result['reason_classification']:
206
+ cat = result['reason_classification']
207
+ reason_counts[cat] = reason_counts.get(cat, 0) + 1
208
+
209
+ print(f"\n Medical reason breakdown:")
210
+ for category, count in sorted(reason_counts.items()):
211
+ percentage = (count / medical_count) * 100
212
+ print(f" {category}: {count} queries ({percentage:.1f}%)")
213
+
214
+ except Exception as e:
215
+ print(f"❌ Error processing batch queries: {e}")
216
+ return False
217
+
218
+ return True
219
+
220
+ def interactive_mode():
221
+ """Interactive mode for complete healthcare classification."""
222
+
223
+ print("🏥 Complete Healthcare Classification System")
224
+ print("=" * 50)
225
+ print("This system provides end-to-end classification:")
226
+ print(" 1️⃣ Medical vs Insurance classification")
227
+ print(" 2️⃣ Medical reason classification (if medical)")
228
+ print(" 3️⃣ Final routing decision")
229
+ print()
230
+ print("Enter healthcare queries to classify (type 'quit' to exit)")
231
+ print()
232
+ print("Example queries to try:")
233
+ print(" Medical: 'I have heel pain when I walk'")
234
+ print(" Medical: 'I need routine foot care'")
235
+ print(" Medical: 'I sprained my ankle'")
236
+ print(" Insurance: 'My insurance claim was denied'")
237
+ print(" Insurance: 'What does my insurance cover?'")
238
+ print()
239
+
240
+ while True:
241
+ try:
242
+ user_input = input("🔍 Enter query >>> ").strip()
243
+
244
+ if user_input.lower() == 'quit':
245
+ print("👋 Goodbye!")
246
+ break
247
+
248
+ if user_input:
249
+ classify_healthcare_query(user_input)
250
+ print("\n" + "="*60)
251
+
252
+ except KeyboardInterrupt:
253
+ print("\n👋 Goodbye!")
254
+ break
255
+ except Exception as e:
256
+ print(f"❌ Error: {e}")
257
+ print()
258
+
259
+ def main():
260
+ parser = argparse.ArgumentParser(
261
+ description='Complete Healthcare Classification CLI',
262
+ formatter_class=argparse.RawDescriptionHelpFormatter,
263
+ epilog="""
264
+ Examples:
265
+ # Interactive mode (recommended)
266
+ python cli/healthcare_classifier_cli.py --interactive
267
+
268
+ # Classify a single query
269
+ python cli/healthcare_classifier_cli.py "I have heel pain"
270
+
271
+ # Batch process queries from file
272
+ python cli/healthcare_classifier_cli.py --batch queries.txt --output results.json
273
+
274
+ Pipeline:
275
+ Query → Medical/Insurance → (if Medical) → Reason Classification → Routing
276
+ """
277
+ )
278
+
279
+ parser.add_argument('query', nargs='?', help='Healthcare query to classify')
280
+ parser.add_argument('--batch', type=str, help='File containing queries to process')
281
+ parser.add_argument('--output', type=str, help='Output file for batch results')
282
+ parser.add_argument('--interactive', action='store_true',
283
+ help='Start interactive mode (recommended)')
284
+
285
+ args = parser.parse_args()
286
+
287
+ # Interactive mode
288
+ if args.interactive:
289
+ interactive_mode()
290
+ return 0
291
+
292
+ # Batch processing
293
+ if args.batch:
294
+ if not Path(args.batch).exists():
295
+ print(f"❌ Error: Batch file does not exist: {args.batch}")
296
+ return 1
297
+
298
+ success = classify_batch_queries(args.batch, args.output)
299
+ return 0 if success else 1
300
+
301
+ # Single query processing
302
+ if args.query:
303
+ result = classify_healthcare_query(args.query)
304
+ return 0 if result else 1
305
+
306
+ # No arguments provided - show help and suggest interactive mode
307
+ print("🏥 Complete Healthcare Classification System")
308
+ print("=" * 45)
309
+ print("IMPORTANT: Activate virtual environment first!")
310
+ print(" source .venv/bin/activate")
311
+ print(" python cli/healthcare_classifier_cli.py --interactive")
312
+ print()
313
+ parser.print_help()
314
+ return 1
315
+
316
+ if __name__ == "__main__":
317
+ sys.exit(main())
cli/reason_classifier_cli.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CLI Interface for Healthcare Reason Classification
3
+
4
+ This provides a command-line interface for testing and using the
5
+ healthcare reason classifier system with real healthcare data.
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ # Add project root to path
14
+ REPO_ROOT = Path(__file__).resolve().parents[1]
15
+ if str(REPO_ROOT) not in sys.path:
16
+ sys.path.insert(0, str(REPO_ROOT))
17
+
18
+ def classify_single_query(query: str):
19
+ """Classify a single healthcare reason query and display results."""
20
+
21
+ print(f"Query: {query}")
22
+ print("-" * 50)
23
+
24
+ try:
25
+ # Import the reason inference module
26
+ sys.path.append('classifier')
27
+ from classifier.reason.infer_reason import predict_single_reason
28
+
29
+ # Get prediction
30
+ result = predict_single_reason(query)
31
+
32
+ print(f"Primary Classification: {result['category']}")
33
+ print(f"Confidence: {result['confidence']:.4f}")
34
+
35
+ # Show all category probabilities
36
+ print(f"\nAll Category Probabilities:")
37
+
38
+ # Sort by probability
39
+ sorted_probs = sorted(result['probabilities'].items(),
40
+ key=lambda x: x[1], reverse=True)
41
+
42
+ for category, prob in sorted_probs:
43
+ print(f" {category}: {prob:.4f}")
44
+
45
+ except Exception as e:
46
+ print(f"Error: {e}")
47
+ print("Note: Make sure the reason classifier is trained")
48
+ return False
49
+
50
+ return True
51
+
52
+ def classify_batch_queries(queries_file: str, output_file: str = None):
53
+ """Classify multiple queries from a file."""
54
+
55
+ try:
56
+ # Read queries
57
+ with open(queries_file, 'r') as f:
58
+ if queries_file.endswith('.json'):
59
+ data = json.load(f)
60
+ if isinstance(data, list):
61
+ queries = data
62
+ else:
63
+ queries = data.get('queries', [])
64
+ else:
65
+ queries = [line.strip() for line in f if line.strip()]
66
+
67
+ print(f"Processing {len(queries)} healthcare reason queries...")
68
+
69
+ # Import the reason inference module
70
+ sys.path.append('classifier')
71
+ from classifier.reason.infer_reason import predict_single_reason
72
+
73
+ results = []
74
+ for i, query in enumerate(queries, 1):
75
+ print(f"\n{i}. Query: {query}")
76
+
77
+ result = predict_single_reason(query)
78
+ results.append(result)
79
+
80
+ print(f" Category: {result['category']} (confidence: {result['confidence']:.3f})")
81
+
82
+ # Save results if output file specified
83
+ if output_file:
84
+ output_data = {
85
+ 'queries': queries,
86
+ 'predictions': results,
87
+ 'summary': {
88
+ 'total_queries': len(queries),
89
+ 'categories': {}
90
+ }
91
+ }
92
+
93
+ # Count categories
94
+ for result in results:
95
+ cat = result['category']
96
+ output_data['summary']['categories'][cat] = output_data['summary']['categories'].get(cat, 0) + 1
97
+
98
+ with open(output_file, 'w') as f:
99
+ json.dump(output_data, f, indent=2)
100
+
101
+ print(f"\nResults saved to {output_file}")
102
+
103
+ # Show summary
104
+ category_counts = {}
105
+ for result in results:
106
+ cat = result['category']
107
+ category_counts[cat] = category_counts.get(cat, 0) + 1
108
+
109
+ print(f"\nSummary:")
110
+ for category, count in sorted(category_counts.items()):
111
+ percentage = (count / len(queries)) * 100
112
+ print(f" {category}: {count} queries ({percentage:.1f}%)")
113
+
114
+ except Exception as e:
115
+ print(f"Error processing batch queries: {e}")
116
+ return False
117
+
118
+ return True
119
+
120
+ def interactive_mode():
121
+ """Interactive mode for testing healthcare reason queries."""
122
+
123
+ print("Healthcare Reason Classifier - Interactive Mode")
124
+ print("=" * 50)
125
+ print("Enter healthcare reason queries to classify (type 'quit' to exit)")
126
+ print()
127
+ print("Example queries to try:")
128
+ print(" • 'I have heel pain when I walk'")
129
+ print(" • 'My toenail is ingrown and infected'")
130
+ print(" • 'I need routine foot care'")
131
+ print(" • 'I sprained my ankle playing basketball'")
132
+ print(" • 'I have plantar fasciitis'")
133
+ print(" • 'I need a cortisone injection'")
134
+ print()
135
+
136
+ while True:
137
+ try:
138
+ user_input = input(">>> ").strip()
139
+
140
+ if user_input.lower() == 'quit':
141
+ break
142
+
143
+ if user_input:
144
+ classify_single_query(user_input)
145
+ print()
146
+
147
+ except KeyboardInterrupt:
148
+ break
149
+ except Exception as e:
150
+ print(f"Error: {e}")
151
+ print()
152
+
153
+ def main():
154
+ parser = argparse.ArgumentParser(
155
+ description='Healthcare Reason Classification CLI',
156
+ formatter_class=argparse.RawDescriptionHelpFormatter,
157
+ epilog="""
158
+ Examples:
159
+ # Classify a single healthcare reason query
160
+ python cli/reason_classifier_cli_new.py "I have heel pain"
161
+
162
+ # Batch process queries from file
163
+ python cli/reason_classifier_cli_new.py --batch reason_queries.txt --output results.json
164
+
165
+ # Interactive mode
166
+ python cli/reason_classifier_cli_new.py --interactive
167
+ """
168
+ )
169
+
170
+ parser.add_argument('query', nargs='?', help='Healthcare reason query to classify')
171
+ parser.add_argument('--batch', type=str, help='File containing queries to process')
172
+ parser.add_argument('--output', type=str, help='Output file for batch results')
173
+ parser.add_argument('--interactive', action='store_true',
174
+ help='Start interactive mode')
175
+
176
+ args = parser.parse_args()
177
+
178
+ # Interactive mode
179
+ if args.interactive:
180
+ interactive_mode()
181
+ return 0
182
+
183
+ # Batch processing
184
+ if args.batch:
185
+ if not Path(args.batch).exists():
186
+ print(f"Error: Batch file does not exist: {args.batch}")
187
+ return 1
188
+
189
+ success = classify_batch_queries(args.batch, args.output)
190
+ return 0 if success else 1
191
+
192
+ # Single query processing
193
+ if args.query:
194
+ success = classify_single_query(args.query)
195
+ return 0 if success else 1
196
+
197
+ # No arguments provided
198
+ parser.print_help()
199
+ return 1
200
+
201
+ if __name__ == "__main__":
202
+ sys.exit(main())
config.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Dict, List
4
+ import torch
5
+ from pydantic_settings import BaseSettings
6
+
7
+ class Settings(BaseSettings):
8
+ # Model Configuration
9
+ MODEL_NAME: str = "sentence-transformers/embeddinggemma-300m-medical"
10
+ CLASSIFIER_NAME: str = "davidgray/health-query-triage"
11
+ CATEGORIES: List[str] = ["medical", "insurance"]
12
+
13
+ # Paths
14
+ CHECKPOINT_PATH: str = "classifier/checkpoints"
15
+ CACHE_DIR: str = ".cache/embeddings"
16
+
17
+ # Device
18
+ DEVICE: str = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
19
+
20
+ # Corpora Configuration
21
+ CORPORA_CONFIG: Dict[str, dict] = {
22
+ "medical_qa": {"path": "data/corpora/medical_qa.jsonl",
23
+ "text_fields": ["question", "answer", "title"]},
24
+ "miriad": {"path": "data/corpora/miriad_text.jsonl",
25
+ "text_fields": ["question", "answer", "title"]},
26
+ "pubmed": {"path": "data/corpora/pubmed.json",
27
+ "text_fields": ["contents","title"]},
28
+ "unidoc": {"path": "data/corpora/unidoc_qa.jsonl",
29
+ "text_fields": ["question", "answer", "title"]},
30
+ }
31
+
32
+ class Config:
33
+ env_file = ".env"
34
+
35
+ settings = Settings()
environment.yml ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: cs410-group-proj
2
+ channels:
3
+ - defaults
4
+ - conda-forge
5
+ - pytorch
6
+ - huggingface
7
+ - anaconda
8
+ dependencies:
9
+ - blas=1.0=openblas
10
+ - brotli-python=1.0.9=py311h313beb8_9
11
+ - bzip2=1.0.8=hd037594_8
12
+ - ca-certificates=2025.10.5=hbd8a1cb_0
13
+ - certifi=2025.10.5=py311hca03da5_0
14
+ - cffi=2.0.0=py311h3a083c1_0
15
+ - cryptography=44.0.1=py311h8026fc7_0
16
+ - faiss-cpu=1.12.0=py3.11_hcb8d3e5_0_cpu
17
+ - gmp=6.3.0=h313beb8_0
18
+ - gmpy2=2.2.1=py311h5c1b81f_0
19
+ - huggingface_hub=0.29.2=py_0
20
+ - icu=75.1=hfee45f7_0
21
+ - jinja2=3.1.6=py311hca03da5_0
22
+ - libcxx=20.1.8=h8869778_0
23
+ - libexpat=2.7.1=hec049ff_0
24
+ - libfaiss=1.12.0=py3.11_hcb8d3e5_0_cpu
25
+ - libffi=3.4.6=h1da3d7d_1
26
+ - libgfortran=5.0.0=11_3_0_hca03da5_28
27
+ - libgfortran5=11.3.0=h009349e_28
28
+ - libidn2=2.3.4=h80987f9_0
29
+ - liblzma=5.8.1=h39f12f2_2
30
+ - libopenblas=0.3.30=hf2bb037_0
31
+ - libsqlite=3.50.4=h4237e3c_0
32
+ - libunistring=0.9.10=h1a28f6b_0
33
+ - libzlib=1.3.1=h5f15de7_0
34
+ - llvm-openmp=20.1.8=he822017_0
35
+ - maven=3.9.11=hce30654_0
36
+ - mpc=1.3.1=h80987f9_0
37
+ - mpfr=4.2.1=h80987f9_0
38
+ - mpmath=1.3.0=py311hca03da5_0
39
+ - ncurses=6.5=h5e97a16_3
40
+ - networkx=3.5=py311hca03da5_0
41
+ - nomkl=3.0=0
42
+ - numpy=1.26.4=py311h901140f_1
43
+ - numpy-base=1.26.4=py311hae06d03_1
44
+ - openblas=0.3.30=hb03180a_0
45
+ - openblas-devel=0.3.30=h1465027_0
46
+ - openjdk=21.0.8=h55d13f6_0
47
+ - openssl=3.5.4=h5503f6c_0
48
+ - packaging=25.0=py311hca03da5_0
49
+ - pip=25.2=pyh8b19718_0
50
+ - pycparser=2.23=py311hca03da5_0
51
+ - pyopenssl=25.0.0=py311h9e2d7d8_0
52
+ - pysocks=1.7.1=py311hca03da5_0
53
+ - python=3.11.14=hec0b533_1_cpython
54
+ - python_abi=3.11=1_cp311
55
+ - pytorch=2.2.2=py3.11_0
56
+ - readline=8.2=h1d1bf99_2
57
+ - requests=2.32.5=py311hca03da5_0
58
+ - setuptools=80.9.0=pyhff2d567_0
59
+ - sympy=1.14.0=py311hca03da5_0
60
+ - tk=8.6.13=h892fb3f_2
61
+ - tqdm=4.67.1=py311hb6e6a13_0
62
+ - typing-extensions=4.15.0=py311hca03da5_0
63
+ - typing_extensions=4.15.0=py311hca03da5_0
64
+ - wget=1.24.5=h3e2b118_0
65
+ - wheel=0.45.1=pyhd8ed1ab_1
66
+ - yaml=0.2.5=h1a28f6b_0
67
+ - zlib=1.3.1=h5f15de7_0
68
+ - pip:
69
+ - aiohappyeyeballs==2.6.1
70
+ - aiohttp==3.13.1
71
+ - aiosignal==1.4.0
72
+ - annotated-types==0.7.0
73
+ - anyio==4.11.0
74
+ - attrs==25.4.0
75
+ - blinker==1.9.0
76
+ - blis==1.3.0
77
+ - catalogue==2.0.10
78
+ - charset-normalizer==3.4.4
79
+ - click==8.3.0
80
+ - cloudpathlib==0.23.0
81
+ - coloredlogs==15.0.1
82
+ - confection==0.1.5
83
+ - cymem==2.0.11
84
+ - cython==3.1.4
85
+ - datasets==2.13.2
86
+ - dill==0.3.6
87
+ - distro==1.9.0
88
+ - fastapi==0.119.0
89
+ - filelock==3.20.0
90
+ - flask==3.1.2
91
+ - flatbuffers==25.9.23
92
+ - frozenlist==1.8.0
93
+ - fsspec==2025.9.0
94
+ - h11==0.16.0
95
+ - hf-xet==1.1.10
96
+ - httpcore==1.0.9
97
+ - httpx==0.28.1
98
+ - httpx-sse==0.4.3
99
+ - huggingface-hub==0.35.3
100
+ - humanfriendly==10.0
101
+ - idna==3.11
102
+ - itsdangerous==2.2.0
103
+ - jiter==0.11.1
104
+ - joblib==1.5.2
105
+ - jsonschema==4.25.1
106
+ - jsonschema-specifications==2025.9.1
107
+ - langcodes==3.5.0
108
+ - language-data==1.3.0
109
+ - marisa-trie==1.3.1
110
+ - markdown-it-py==4.0.0
111
+ - markupsafe==3.0.3
112
+ - mcp==1.18.0
113
+ - mdurl==0.1.2
114
+ - multidict==6.7.0
115
+ - multiprocess==0.70.14
116
+ - murmurhash==1.0.13
117
+ - onnxruntime==1.23.1
118
+ - openai==2.5.0
119
+ - pandas==2.3.3
120
+ - pillow==12.0.0
121
+ - preshed==3.0.10
122
+ - propcache==0.4.1
123
+ - protobuf==6.33.0
124
+ - pyarrow==11.0.0
125
+ - pybind11==3.0.1
126
+ - pydantic==2.12.3
127
+ - pydantic-core==2.41.4
128
+ - pydantic-settings==2.11.0
129
+ - pygments==2.19.2
130
+ - pyjnius==1.7.0
131
+ - pyserini==1.2.0
132
+ - python-dateutil==2.9.0.post0
133
+ - python-dotenv==1.1.1
134
+ - python-multipart==0.0.20
135
+ - pytz==2025.2
136
+ - pyyaml==6.0.3
137
+ - referencing==0.37.0
138
+ - regex==2025.9.18
139
+ - rich==14.2.0
140
+ - rpds-py==0.27.1
141
+ - safetensors==0.6.2
142
+ - scikit-learn==1.7.2
143
+ - scipy==1.16.2
144
+ - sentencepiece==0.2.1
145
+ - shellingham==1.5.4
146
+ - six==1.17.0
147
+ - smart-open==7.3.1
148
+ - sniffio==1.3.1
149
+ - spacy==3.8.7
150
+ - spacy-legacy==3.0.12
151
+ - spacy-loggers==1.0.5
152
+ - srsly==2.5.1
153
+ - sse-starlette==3.0.2
154
+ - starlette==0.48.0
155
+ - thinc==8.3.6
156
+ - threadpoolctl==3.6.0
157
+ - tiktoken==0.12.0
158
+ - tokenizers==0.22.1
159
+ - torch==2.9.0
160
+ - torchaudio==2.9.0
161
+ - torchvision==0.24.0
162
+ - transformers==4.57.1
163
+ - typer==0.19.2
164
+ - typing-inspection==0.4.2
165
+ - tzdata==2025.2
166
+ - urllib3==2.5.0
167
+ - uvicorn==0.38.0
168
+ - wasabi==1.1.3
169
+ - weasel==0.4.1
170
+ - werkzeug==3.1.3
171
+ - wrapt==2.0.0
172
+ - xxhash==3.6.0
173
+ - yarl==1.22.0
174
+ prefix: /opt/homebrew/Caskroom/miniconda/base/envs/cs410-group-proj
launch_ui.bat ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ REM Medical Q&A Bot - Easy Launcher
3
+ REM Double-click this file to launch the web UI
4
+
5
+ echo ========================================
6
+ echo Medical Q&A Bot - Web Interface
7
+ echo ========================================
8
+ echo.
9
+
10
+ REM Check if virtual environment exists
11
+ if exist ".venv\Scripts\activate.bat" (
12
+ echo Activating virtual environment...
13
+ call .venv\Scripts\activate.bat
14
+ ) else (
15
+ echo No virtual environment found. Using system Python.
16
+ )
17
+
18
+ echo.
19
+ echo Launching Gradio interface...
20
+ echo The web UI will open at: http://127.0.0.1:7860
21
+ echo.
22
+ echo Press Ctrl+C to stop the server
23
+ echo ========================================
24
+ echo.
25
+
26
+ python app.py
27
+
28
+ pause
launch_ui.ps1 ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Medical Q&A Bot - PowerShell Launcher
2
+ # Run this script to launch the web UI
3
+
4
+ Write-Host "========================================" -ForegroundColor Cyan
5
+ Write-Host "Medical Q&A Bot - Web Interface" -ForegroundColor Cyan
6
+ Write-Host "========================================" -ForegroundColor Cyan
7
+ Write-Host ""
8
+
9
+ # Check if virtual environment exists
10
+ if (Test-Path ".venv\Scripts\Activate.ps1") {
11
+ Write-Host "Activating virtual environment..." -ForegroundColor Green
12
+ & .venv\Scripts\Activate.ps1
13
+ } else {
14
+ Write-Host "No virtual environment found. Using system Python." -ForegroundColor Yellow
15
+ }
16
+
17
+ Write-Host ""
18
+ Write-Host "Launching Gradio interface..." -ForegroundColor Green
19
+ Write-Host "The web UI will open at: http://127.0.0.1:7860" -ForegroundColor Yellow
20
+ Write-Host ""
21
+ Write-Host "Press Ctrl+C to stop the server" -ForegroundColor Red
22
+ Write-Host "========================================" -ForegroundColor Cyan
23
+ Write-Host ""
24
+
25
+ # Launch the app
26
+ python app.py
27
+
28
+ # Keep window open on error
29
+ if ($LASTEXITCODE -ne 0) {
30
+ Write-Host ""
31
+ Write-Host "An error occurred. Press any key to exit..." -ForegroundColor Red
32
+ $null = $Host.UI.RawUI.ReadKey("NoEcho,IncludeKeyDown")
33
+ }
main.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from dataclasses import asdict
4
+
5
+ from pipeline import HealthQueryPipeline
6
+
7
+ EXIT_COMMANDS = ["exit", "quit"]
8
+ PROMPT = "\nQuery> "
9
+
10
+ def main(pipeline: HealthQueryPipeline, k: int) -> None:
11
+ print(f"(Ctrl-D or 'quit' to exit)")
12
+
13
+ while True:
14
+ try:
15
+ query = input(PROMPT).strip()
16
+ if not query or query.lower() in EXIT_COMMANDS:
17
+ break
18
+
19
+ # Show index status
20
+ curr, total = pipeline.get_index_progress()
21
+ if total > 0:
22
+ pct = int((curr / total) * 100)
23
+ if pct < 100:
24
+ print(f"[Index: {pct}% loaded]")
25
+
26
+ # Use the pipeline to get results
27
+ result = pipeline.predict(query, k=k)
28
+
29
+ classification = result["classification"]
30
+ prediction = classification["prediction"]
31
+
32
+ print(f"\nTriaging query as {prediction}")
33
+ print(f"\nConfidence:")
34
+ for cat, prob in classification["probabilities"].items():
35
+ percent = prob * 100
36
+ print(f" {cat}: {percent:3.2f}%")
37
+ print()
38
+
39
+ if "medical" == prediction:
40
+ hits = result["retrieval"]
41
+ print(f"Found {len(hits)} matching medical documents\n")
42
+
43
+ if not hits:
44
+ print("No medical documents found.\n")
45
+ continue
46
+
47
+ for i, hit in enumerate(hits, 1):
48
+ # hit is already a dict from the pipeline
49
+ print(json.dumps(hit, indent=2, ensure_ascii=False))
50
+ else:
51
+ print(f"TODO: handle queries of type {prediction}")
52
+ continue
53
+
54
+ except EOFError:
55
+ print("\nBye!")
56
+ break
57
+
58
+ except KeyboardInterrupt:
59
+ print("\nBye!")
60
+ break
61
+
62
+
63
+ if __name__ == "__main__":
64
+ ap = argparse.ArgumentParser(
65
+ description="Hybrid retrieval (BM25 + Dense + RRF, optional re-rank)"
66
+ )
67
+ ap.add_argument("--k", type=int, default=10, help="Number of results to return")
68
+ ap.add_argument(
69
+ "--rerank", action="store_true",
70
+ help="Use cross-encoder reranker (slower, usually better)"
71
+ )
72
+ args = ap.parse_args()
73
+
74
+ # Initialize pipeline
75
+ pipeline = HealthQueryPipeline(use_reranker=args.rerank)
76
+ pipeline.initialize()
77
+
78
+ main(pipeline, k=args.k)
pipeline.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from dataclasses import asdict
3
+ from typing import List, Dict, Any, Optional
4
+
5
+ from sentence_transformers import SentenceTransformer
6
+ from classifier.head import ClassifierHead
7
+ from classifier.infer import predict_query
8
+ from classifier.utils import get_models
9
+ from retriever import Retriever
10
+ from team.candidates import get_candidates, _available
11
+ from config import settings
12
+
13
+ class HealthQueryPipeline:
14
+ def __init__(self, use_reranker: bool = False):
15
+ self.use_reranker = use_reranker
16
+ self.embedding_model: Optional[SentenceTransformer] = None
17
+ self.classifier: Optional[ClassifierHead] = None
18
+ self.retriever: Optional[Retriever] = None
19
+ self.is_initialized = False
20
+
21
+ def initialize(self):
22
+ """Loads models and initializes the retriever."""
23
+ if self.is_initialized:
24
+ return
25
+
26
+ print(f"Loading embedding model: {settings.MODEL_NAME}...")
27
+ self.embedding_model, self.classifier = get_models(model_id=settings.CLASSIFIER_NAME)
28
+ print("Model loaded.")
29
+
30
+ print("Initializing retriever...")
31
+ cfg = _available(settings.CORPORA_CONFIG)
32
+ if not cfg:
33
+ raise RuntimeError("No corpora files found in data/corpora. Build them first.")
34
+
35
+ self.retriever = Retriever(
36
+ corpora_config=cfg,
37
+ use_reranker=self.use_reranker,
38
+ embedding_model=self.embedding_model
39
+ )
40
+ print("Retriever initialized.")
41
+ self.is_initialized = True
42
+
43
+ def predict(self, query: str, k: int = 10) -> Dict[str, Any]:
44
+ """
45
+ Runs the full pipeline: Classification -> Retrieval (if medical).
46
+ """
47
+ if not self.is_initialized:
48
+ self.initialize()
49
+
50
+ classification = predict_query(
51
+ text=[query],
52
+ embedding_model=self.embedding_model,
53
+ classifier_head=self.classifier,
54
+ )
55
+
56
+ predictions = classification["prediction"]
57
+ result = {
58
+ "query": query,
59
+ "classification": {
60
+ "prediction": predictions[0],
61
+ "probabilities": {
62
+ cat: prob
63
+ for cat, prob in zip(settings.CATEGORIES, classification['probabilities'])
64
+ }
65
+ },
66
+ "retrieval": []
67
+ }
68
+
69
+ if "medical" in predictions:
70
+ hits = get_candidates(
71
+ query=query,
72
+ retriever=self.retriever,
73
+ k_retrieve=k,
74
+ )
75
+ result["retrieval"] = [asdict(hit) for hit in hits]
76
+
77
+ return result
78
+
79
+ def get_index_progress(self):
80
+ """Returns (current, total) of the underlying index."""
81
+ if not self.retriever:
82
+ return 0, 0
83
+ return self.retriever.get_index_progress()
reason_data_analysis.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple script to analyze healthcare reason data processing
3
+ """
4
+
5
+ import pandas as pd
6
+ import sys
7
+ import os
8
+
9
+ # Add current directory to path
10
+ sys.path.append('.')
11
+
12
+ def test_data_loading():
13
+ """Test loading and processing the healthcare reason data"""
14
+
15
+ print("Testing Healthcare Reason Data Processing")
16
+ print("=" * 40)
17
+
18
+ # Load the data
19
+ try:
20
+ df = pd.read_excel('data/reason_for_visit_data.xlsx')
21
+ print(f"✅ Successfully loaded {len(df)} records")
22
+ except Exception as e:
23
+ print(f"❌ Error loading data: {e}")
24
+ return False
25
+
26
+ # Analyze the data
27
+ print(f"\nDataset Info:")
28
+ print(f"Shape: {df.shape}")
29
+ print(f"Columns: {list(df.columns)}")
30
+
31
+ # Show reason distribution
32
+ print(f"\nTop 10 Reasons for Visit:")
33
+ top_reasons = df['Reason For Visit'].value_counts().head(10)
34
+ for reason, count in top_reasons.items():
35
+ print(f" {reason}: {count}")
36
+
37
+ # Test categorization logic
38
+ def map_reason_to_category(reason: str) -> str:
39
+ """Simple categorization logic"""
40
+ reason_lower = reason.lower()
41
+
42
+ if any(word in reason_lower for word in ['routine', 'nail care', 'calluses']):
43
+ return "ROUTINE_CARE"
44
+ elif any(word in reason_lower for word in ['pain', 'ache', 'sore']):
45
+ return "PAIN_CONDITIONS"
46
+ elif any(word in reason_lower for word in ['sprain', 'wound', 'injury']):
47
+ return "INJURIES"
48
+ elif any(word in reason_lower for word in ['ingrown', 'toenail', 'callus']):
49
+ return "SKIN_CONDITIONS"
50
+ elif any(word in reason_lower for word in ['flat feet', 'plantar', 'fasciitis', 'achilles']):
51
+ return "STRUCTURAL_ISSUES"
52
+ elif any(word in reason_lower for word in ['injection', 'surgical', 'consult', 'postop']):
53
+ return "PROCEDURES"
54
+ else:
55
+ return "PAIN_CONDITIONS" # Default
56
+
57
+ # Apply categorization
58
+ df['Category'] = df['Reason For Visit'].apply(map_reason_to_category)
59
+
60
+ print(f"\nCategory Distribution:")
61
+ category_counts = df['Category'].value_counts()
62
+ for category, count in category_counts.items():
63
+ percentage = (count / len(df)) * 100
64
+ print(f" {category}: {count} ({percentage:.1f}%)")
65
+
66
+ # Show examples for each category
67
+ print(f"\nExample reasons by category:")
68
+ for category in category_counts.index:
69
+ examples = df[df['Category'] == category]['Reason For Visit'].head(3).tolist()
70
+ print(f" {category}:")
71
+ for example in examples:
72
+ print(f" - {example}")
73
+
74
+ return True
75
+
76
+ if __name__ == "__main__":
77
+ success = test_data_loading()
78
+ if success:
79
+ print("\n✅ Healthcare reason data analysis completed successfully!")
80
+ else:
81
+ print("\n❌ Healthcare reason data analysis failed!")
requirements-admin.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Requirements for Administrative Query Classifier
2
+ pandas>=1.5.0
3
+ scikit-learn>=1.0.0
4
+ setfit>=1.0.0
5
+ sentence-transformers>=2.0.0
6
+ datasets>=2.0.0
7
+ matplotlib>=3.5.0
8
+ seaborn>=0.11.0
9
+ numpy>=1.21.0
requirements-train.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ matplotlib
2
+ numpy
3
+ pandas
4
+ sentence-transformers
5
+ torch
6
+ huggingface_hub
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets
2
+ pyarrow
3
+ jsonlines
4
+ tqdm
5
+ rank-bm25
6
+ faiss-cpu
7
+ sentence-transformers
8
+ numpy
9
+ scipy
10
+ scikit-learn
11
+ torch
12
+ huggingface_hub
13
+ gradio
14
+ streamlit
retriever/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .search import Retriever
retriever/data_schemas.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ @dataclass
4
+ class Doc:
5
+ id: str
6
+ text: str
7
+ title: str | None = None
8
+ meta: dict | None = None
retriever/index_bm25.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rank_bm25 import BM25Okapi
2
+ from .utils import tokenize
3
+
4
+ class BM25Index:
5
+ def __init__(self, docs):
6
+ self.docs = docs
7
+ self.corpus_tokens = [tokenize(d.text) for d in docs]
8
+ self.bm25 = BM25Okapi(self.corpus_tokens)
9
+
10
+ def search(self, query: str, k: int = 50):
11
+ q = tokenize(query)
12
+ scores = self.bm25.get_scores(q)
13
+ top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
14
+ return [(self.docs[i], float(scores[i])) for i in top]
retriever/index_dense.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # retriever/index_dense.py
2
+ import os
3
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
4
+
5
+ import hashlib
6
+ import threading
7
+ import numpy as np
8
+ import pickle
9
+ import torch
10
+ from pathlib import Path
11
+ from sentence_transformers import SentenceTransformer
12
+ from classifier.utils import DEVICE
13
+
14
+ try:
15
+ import faiss # type: ignore
16
+ _HAS_FAISS = True
17
+ except Exception:
18
+ _HAS_FAISS = False
19
+
20
+ def _chunks(lst, n):
21
+ for i in range(0, len(lst), n):
22
+ yield lst[i:i+n]
23
+
24
+ def _compute_cache_key(docs, model_name):
25
+ """Compute a hash key for caching based on documents and model."""
26
+ # Create a hash from document IDs/texts and model name
27
+ doc_ids = "".join([d.id for d in docs])
28
+ content = f"{model_name}:{doc_ids}"
29
+ return hashlib.md5(content.encode()).hexdigest()
30
+
31
+ class DenseIndex:
32
+ def __init__(self, docs, model_name="sentence-transformers/embeddinggemma-300m-medical",
33
+ batch_size=64, embedding_model=None, cache_dir=".cache/embeddings"):
34
+ self.docs = docs
35
+ self.batch_size = batch_size
36
+ self.cache_dir = cache_dir
37
+
38
+ # Thread safety
39
+ self.lock = threading.Lock()
40
+ self.ready_count = 0
41
+ self.emb_batches = [] # List of numpy arrays for fallback
42
+
43
+ torch.set_num_threads(1)
44
+ if embedding_model:
45
+ self.model = embedding_model
46
+ self.device = self.model.device
47
+ actual_model_name = getattr(self.model, 'model_card_data', {}).get('base_model', model_name)
48
+ if hasattr(self.model, '_model_card_vars') and 'model_id' in self.model._model_card_vars:
49
+ actual_model_name = self.model._model_card_vars['model_id']
50
+ else:
51
+ self.model = SentenceTransformer(model_name, device=DEVICE)
52
+ self.device = DEVICE
53
+ actual_model_name = model_name
54
+
55
+ self.cache_key = _compute_cache_key(docs, actual_model_name)
56
+ self.cache_path = Path(cache_dir) / f"{self.cache_key}.pkl"
57
+
58
+ # Initialize index structure
59
+ if _HAS_FAISS:
60
+ # We need to know dimension to init FAISS.
61
+ # We'll init it when the first batch arrives or if we load full cache.
62
+ self.index = None
63
+ else:
64
+ self.index = None
65
+
66
+ # Start background ingestion
67
+ self.ingest_thread = threading.Thread(target=self._ingest_embeddings, daemon=True)
68
+ self.ingest_thread.start()
69
+
70
+ def _generate_embeddings(self):
71
+ """Yields batches of embeddings from cache or computation."""
72
+ texts = [d.text for d in self.docs]
73
+
74
+ # 1. Try full cache first
75
+ if self.cache_path.exists():
76
+ print(f"Loading embeddings from cache: {self.cache_path}")
77
+ try:
78
+ with open(self.cache_path, 'rb') as f:
79
+ full_emb = pickle.load(f)
80
+ print(f"✓ Loaded {len(full_emb)} cached embeddings")
81
+ # Yield as a single large batch
82
+ yield full_emb
83
+ return
84
+ except Exception as e:
85
+ print(f"Cache load failed: {e}, recomputing...")
86
+
87
+ # 2. Partial cache logic
88
+ partial_cache_path = self.cache_path.parent / f"{self.cache_path.stem}.partial.pkl"
89
+ start_index = 0
90
+ existing_embs = []
91
+
92
+ if partial_cache_path.exists():
93
+ try:
94
+ with open(partial_cache_path, 'rb') as f:
95
+ existing_embs = pickle.load(f)
96
+
97
+ # Yield existing chunks
98
+ # We assume existing_embs is a list of batches from previous run
99
+ # But wait, previous implementation saved list of batches.
100
+ # Let's verify if it saved list of batches or vstacked array.
101
+ # Previous impl: pickle.dump(embs, f) where embs is list of arrays.
102
+
103
+ for batch in existing_embs:
104
+ yield batch
105
+
106
+ start_index = sum(len(e) for e in existing_embs)
107
+ except Exception as e:
108
+ existing_embs = []
109
+ start_index = 0
110
+
111
+ # 3. Compute remaining
112
+ texts_to_process = texts[start_index:]
113
+ if not texts_to_process:
114
+ return
115
+
116
+ # We need to keep track of all embs (existing + new) to save partial/full cache
117
+ # But `existing_embs` might be large.
118
+ # We will append new batches to `existing_embs` locally to save partials.
119
+
120
+ with torch.inference_mode():
121
+ total_processed = start_index
122
+ total_batches = (len(texts) + self.batch_size - 1) // self.batch_size
123
+ start_batch = len(existing_embs)
124
+
125
+ for i, part in enumerate(_chunks(texts_to_process, self.batch_size), 1):
126
+ part_emb = self.model.encode(
127
+ part,
128
+ batch_size=self.batch_size,
129
+ normalize_embeddings=True,
130
+ convert_to_numpy=True,
131
+ show_progress_bar=False,
132
+ device=self.device,
133
+ )
134
+ batch_emb = part_emb.astype(np.float32)
135
+ yield batch_emb
136
+
137
+ existing_embs.append(batch_emb)
138
+ total_processed += len(part)
139
+
140
+ # Save partial
141
+ with open(partial_cache_path, 'wb') as f:
142
+ pickle.dump(existing_embs, f)
143
+
144
+ def _ingest_embeddings(self):
145
+ """Background thread to ingest embeddings from generator."""
146
+ all_embs = []
147
+
148
+ for batch_emb in self._generate_embeddings():
149
+ with self.lock:
150
+ if _HAS_FAISS:
151
+ if self.index is None:
152
+ d = batch_emb.shape[1]
153
+ self.index = faiss.IndexFlatIP(d)
154
+ self.index.add(batch_emb)
155
+
156
+ # We also keep track for fallback or saving
157
+ self.emb_batches.append(batch_emb)
158
+ self.ready_count += len(batch_emb)
159
+
160
+ all_embs.append(batch_emb)
161
+
162
+ # Finalize
163
+ full_emb = np.vstack(all_embs).astype(np.float32)
164
+
165
+ # Save full cache
166
+ self.cache_path.parent.mkdir(parents=True, exist_ok=True)
167
+ with open(self.cache_path, 'wb') as f:
168
+ pickle.dump(full_emb, f)
169
+ print(f"✓ Saved embeddings to cache: {self.cache_path}")
170
+
171
+ # Cleanup partial
172
+ partial_cache_path = self.cache_path.parent / f"{self.cache_path.stem}.partial.pkl"
173
+ if partial_cache_path.exists():
174
+ partial_cache_path.unlink()
175
+
176
+ def search(self, query: str, k: int = 50):
177
+ qv = self.model.encode(
178
+ [query],
179
+ normalize_embeddings=True,
180
+ convert_to_numpy=True,
181
+ show_progress_bar=False,
182
+ device=self.device,
183
+ ).astype(np.float32)[0]
184
+
185
+ with self.lock:
186
+ current_count = self.ready_count
187
+ if current_count == 0:
188
+ print("Warning: Index not yet initialized, returning empty results.")
189
+ return []
190
+
191
+ # If we have partial data, we search it.
192
+ if _HAS_FAISS and self.index is not None:
193
+ # FAISS index is updated incrementally
194
+ D, I = self.index.search(qv.reshape(1, -1), min(k, current_count))
195
+ return [(self.docs[int(i)], float(D[0][j])) for j, i in enumerate(I[0]) if i != -1]
196
+
197
+ # NumPy fallback
198
+ # We might have multiple batches, need to stack them for search
199
+ # Optimization: cache the stacked version if it hasn't changed?
200
+ # For now, just stack what we have.
201
+ curr_emb = np.vstack(self.emb_batches)
202
+
203
+ sims = curr_emb @ qv
204
+ effective_k = min(k, len(sims))
205
+
206
+ if effective_k >= len(sims):
207
+ order = np.argsort(-sims)
208
+ else:
209
+ idx = np.argpartition(-sims, kth=effective_k-1)[:effective_k]
210
+ order = idx[np.argsort(-sims[idx])]
211
+
212
+ return [(self.docs[int(i)], float(sims[int(i)])) for i in order]
213
+
214
+ def get_progress(self):
215
+ """Returns (current_count, total_count) of indexed documents."""
216
+ with self.lock:
217
+ return self.ready_count, len(self.docs)
retriever/ingest.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, pathlib
2
+ from .data_schemas import Doc
3
+
4
+ def load_jsonl(path: str, text_fields=("question","answer")):
5
+ p = pathlib.Path(path)
6
+ docs = []
7
+ with p.open(encoding="utf-8") as f:
8
+ for i, line in enumerate(f):
9
+ row = json.loads(line)
10
+ # Collect fields; allow either "text" or joined fields
11
+ if "text" in row and row["text"]:
12
+ combined = row["text"]
13
+ else:
14
+ combined = " ".join([row.get(tf, "") for tf in text_fields]).strip()
15
+ title = row.get("title") or row.get("category") or ""
16
+ docs.append(Doc(
17
+ id=str(row.get("id", f"{p.stem}:{i}")),
18
+ text=combined,
19
+ title=title,
20
+ meta=row
21
+ ))
22
+ return docs
retriever/rrf.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+
3
+ def rrf(rank_lists, k=10, K=60):
4
+ scores = defaultdict(float)
5
+ id2doc = {}
6
+ for rl in rank_lists:
7
+ for r, (doc, _) in enumerate(rl):
8
+ id2doc[doc.id] = doc
9
+ scores[doc.id] += 1.0 / (K + r + 1)
10
+ ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:k]
11
+ return [(id2doc[i], s) for i, s in ranked]
retriever/search.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .index_bm25 import BM25Index
2
+ from .index_dense import DenseIndex
3
+ from .rrf import rrf
4
+ try:
5
+ from .rerank import CrossEncoderReranker
6
+ except Exception:
7
+ CrossEncoderReranker = None
8
+ from .ingest import load_jsonl
9
+
10
+ class Retriever:
11
+ def __init__(self, corpora_config, use_reranker=False, embedding_model=None):
12
+ self.corpora = {}
13
+ docs_all = []
14
+ for name, cfg in corpora_config.items():
15
+ docs = load_jsonl(cfg["path"], tuple(cfg.get("text_fields", ("question","answer"))))
16
+ self.corpora[name] = docs
17
+ docs_all.extend(docs)
18
+ self.bm25 = BM25Index(docs_all)
19
+ self.dense = DenseIndex(docs_all, embedding_model=embedding_model)
20
+ self.reranker = CrossEncoderReranker() if (use_reranker and CrossEncoderReranker) else None
21
+
22
+ def retrieve(self, query, k=10, for_ui=True):
23
+ bm = self.bm25.search(query, k=100)
24
+ de = self.dense.search(query, k=100)
25
+ fused = rrf([bm, de], k=max(k, 20))
26
+ if self.reranker:
27
+ reranked = self.reranker.rerank(query, [d for d, _ in fused])[:k]
28
+ results = [(d, float(s)) for d, s in reranked]
29
+ else:
30
+ results = fused[:k]
31
+ if not for_ui:
32
+ return results
33
+ return [{
34
+ "id": d.id,
35
+ "title": d.title,
36
+ "snippet": d.text[:300] + ("..." if len(d.text) > 300 else ""),
37
+ "score": s,
38
+ "meta": d.meta
39
+ } for d, s in results]
40
+
41
+ def get_index_progress(self):
42
+ """Returns (current, total) from dense index."""
43
+ return self.dense.get_progress()
retriever/utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ # Simple non-word splitter (keeps letters/numbers, splits on punctuation/whitespace)
4
+ _WS = re.compile(r"\W+", flags=re.UNICODE)
5
+
6
+ def tokenize(s: str) -> list[str]:
7
+ """
8
+ Lowercase + split on non-word chars. Returns [] for None/empty.
9
+ Used by BM25 to build the tokenized corpus and query.
10
+ """
11
+ if not s:
12
+ return []
13
+ return [t for t in _WS.split(s.lower()) if t]
14
+
15
+
scripts/query.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+
3
+ import json
4
+ import readline
5
+ import sys
6
+
7
+ from pyserini.search.lucene import LuceneSearcher
8
+
9
+
10
+ def main():
11
+ index_dir = sys.argv[1] if len(sys.argv) > 1 else "indexes/pubmed"
12
+
13
+ searcher = LuceneSearcher(index_dir)
14
+
15
+ print(f"Loaded {searcher.num_docs} documents from {index_dir}")
16
+ print(f"(Ctrl-D or 'quit' to exit)\n")
17
+
18
+ while True:
19
+ try:
20
+ query = input("PubMed> ").strip()
21
+ if not query or query.lower() in ['quit', 'exit']:
22
+ break
23
+
24
+ hits = searcher.search(query, k=10)
25
+
26
+ print(f"{len(hits)}/{searcher.num_docs} matching documents found\n")
27
+
28
+ if not hits:
29
+ print("No results found.\n")
30
+
31
+ continue
32
+
33
+ for i, hit in enumerate(hits, 1):
34
+ doc = searcher.doc(hit.docid)
35
+
36
+ raw = json.loads(doc.raw())
37
+
38
+ title = raw.get('title', '')
39
+ contents = raw.get('contents', '')
40
+
41
+ abstract = contents[len(title):] if contents.startswith(title) else contents
42
+
43
+ print(f"{i}. PMID {hit.docid} \"{title}\" (score: {hit.score:.4f})")
44
+ print(f" {abstract[:120]}...\n")
45
+
46
+ except EOFError:
47
+ print("\nBye!")
48
+
49
+ break
50
+
51
+ except KeyboardInterrupt:
52
+ print("\nBye!")
53
+
54
+ break
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()