imurra committed on
Commit
2b376ce
·
verified ·
1 Parent(s): 5ef4c55

New app to have Med-Gemini and MedQA

Browse files
Files changed (1) hide show
  1. app.py +424 -103
app.py CHANGED
@@ -1,115 +1,436 @@
1
- import os
2
- os.environ['ANONYMIZED_TELEMETRY'] = 'False'
3
-
4
- import zipfile
5
- import chromadb
6
- from sentence_transformers import SentenceTransformer
7
  import gradio as gr
8
- from fastapi import FastAPI
9
- from pydantic import BaseModel
 
 
 
 
10
 
11
- # Extract and load database
12
- DB_PATH = "./medqa_db"
13
- if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
14
- with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
15
- z.extractall(".")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- client = chromadb.PersistentClient(path=DB_PATH)
18
- collection = client.get_collection("medqa")
19
- model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
20
 
21
- # Search function
22
- def search(query, num_results=3):
23
- emb = model.encode(query).tolist()
24
- return collection.query(query_embeddings=[emb], n_results=int(num_results))
25
 
26
- # Gradio UI
27
- def ui_search(query, num_results=3):
28
- if not query.strip():
29
- return "Enter a query"
30
- try:
31
- r = search(query, num_results)
32
- out = ""
33
- for i in range(len(r['documents'][0])):
34
- out += f"\n{'='*60}\nExample {i+1}\n{'='*60}\n"
35
-
36
- # Get the full question text
37
- question_text = r['documents'][0][i]
38
-
39
- # DEBUG: Show raw text
40
- out += "DEBUG RAW TEXT:\n"
41
- out += repr(question_text) + "\n"
42
- out += "="*60 + "\n\n"
43
-
44
- # Parse question and answer choices
45
- import re
46
- # Look for answer choices pattern (A. or A) followed by text)
47
- lines = question_text.split('\n')
48
- question_part = []
49
- choices_part = []
50
- in_choices = False
51
-
52
- for line in lines:
53
- # Check if line starts with A-E followed by . or )
54
- if re.match(r'^[A-E][\.\)]', line.strip()):
55
- in_choices = True
56
- choices_part.append(line)
57
- elif in_choices:
58
- # Continue collecting choices if they span multiple lines
59
- if line.strip() and not re.match(r'^[A-E][\.\)]', line.strip()):
60
- choices_part[-1] += " " + line.strip()
61
- elif re.match(r'^[A-E][\.\)]', line.strip()):
62
- choices_part.append(line)
63
- else:
64
- question_part.append(line)
65
-
66
- # Display question
67
- out += '\n'.join(question_part).strip() + "\n\n"
68
-
69
- # Display choices if found
70
- if choices_part:
71
- out += "Answer Choices:\n"
72
- for choice in choices_part:
73
- out += choice.strip() + "\n"
74
- out += "\n"
75
-
76
- out += f"Correct Answer: {r['metadatas'][0][i].get('answer', 'N/A')}\n"
77
- out += f"Similarity: {1 - r['distances'][0][i]:.3f}\n"
78
- return out
79
- except Exception as e:
80
- return f"Error: {e}"
81
 
82
- demo = gr.Interface(
83
- fn=ui_search,
84
- inputs=[
85
- gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia"),
86
- gr.Slider(1, 5, value=3, step=1, label="Results")
87
- ],
88
- outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
89
- title="MedQA Search",
90
- examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
91
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- # FastAPI
94
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- class SearchRequest(BaseModel):
97
- query: str
98
- num_results: int = 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- @app.post("/search_medqa")
101
- def api_search(req: SearchRequest):
102
- r = search(req.query, req.num_results)
103
- return {"results": [{
104
- "example_number": i+1,
105
- "question": r['documents'][0][i],
106
- "answer": r['metadatas'][0][i].get('answer', 'N/A'),
107
- "distance": r['distances'][0][i]
108
- } for i in range(len(r['documents'][0]))]}
109
 
110
- app = gr.mount_gradio_app(app, demo, path="/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- # Launch the server
113
  if __name__ == "__main__":
114
- import uvicorn
115
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import json
3
+ import zipfile
4
+ from pathlib import Path
5
+ import pandas as pd
6
+ from typing import Dict, List, Tuple
7
+ import random
8
 
9
class MedQADatabase:
    """Handler for MedQA and Med-Gemini databases.

    Question records are kept in ``self.data``, a dict with one entry per
    dataset: ``'medgemini'``, ``'medqa_train'``, ``'medqa_dev'`` and
    ``'medqa_test'``. Each entry holds whatever the matching JSON file
    contains (the search code treats it as a list of dict records).
    """

    def __init__(self, zip_path="medqa_databases.zip"):
        """Create the empty dataset table and fill it from *zip_path*."""
        self.data = {
            'medgemini': [],
            'medqa_train': [],
            'medqa_dev': [],
            'medqa_test': []
        }
        self.load_databases(zip_path)

    def load_databases(self, zip_path):
        """Load all databases from the ZIP file.

        The archive is extracted into ``temp_data`` and the JSON files found
        there are parsed. A missing JSON file is skipped and leaves that
        dataset empty; any other failure is printed and re-raised.
        """
        print("πŸ“¦ Loading databases from ZIP...")

        try:
            # Extract to temporary directory; the archive is closed as soon
            # as extraction finishes, before any JSON parsing starts.
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall('temp_data')

            # Load Med-Gemini
            medgemini_path = Path('temp_data/medqa_databases/med_gemini/medqa_relabelling.json')
            if medgemini_path.exists():
                with open(medgemini_path, 'r', encoding='utf-8') as f:
                    self.data['medgemini'] = json.load(f)
                print(f"βœ… Loaded {len(self.data['medgemini'])} Med-Gemini questions")

            # Load MedQA splits
            medqa_base = Path('temp_data/medqa_databases/medqa_original')
            for split in ['train', 'dev', 'test']:
                split_path = medqa_base / f"{split}.json"
                if split_path.exists():
                    with open(split_path, 'r', encoding='utf-8') as f:
                        self.data[f'medqa_{split}'] = json.load(f)
                    print(f"βœ… Loaded {len(self.data[f'medqa_{split}'])} MedQA {split} questions")

        except Exception as e:
            print(f"❌ Error loading databases: {e}")
            raise

    def get_stats(self) -> str:
        """Return a Markdown summary of how many questions each dataset holds."""
        medqa_total = sum(len(self.data[f'medqa_{s}']) for s in ['train', 'dev', 'test'])
        grand_total = sum(len(v) for v in self.data.values())

        stats = "## πŸ“Š Database Statistics\n\n"
        stats += f"**Med-Gemini**: {len(self.data['medgemini']):,} questions\n\n"
        stats += "**MedQA Original**:\n"
        stats += f"- Training: {len(self.data['medqa_train']):,} questions\n"
        stats += f"- Development: {len(self.data['medqa_dev']):,} questions\n"
        stats += f"- Test: {len(self.data['medqa_test']):,} questions\n"
        stats += f"- **Total**: {medqa_total:,} questions\n\n"
        stats += f"**Grand Total**: {grand_total:,} questions"
        return stats

    def get_question(self, dataset: str, index: int) -> Dict:
        """Return record *index* of *dataset*.

        Returns ``None`` (not a dict) when the dataset name is unknown or the
        index is out of range, so callers must check the result.
        """
        try:
            return self.data[dataset][index]
        except (KeyError, IndexError):
            return None

    def search_questions(self, query: str, dataset: str = 'all', max_results: int = 50) -> List[Tuple[str, int, str]]:
        """Search questions by case-insensitive keyword match.

        Args:
            query: Substring to look for in the question text.
            dataset: A key of ``self.data``, or ``'all'`` to scan every dataset.
            max_results: Upper bound on the number of hits returned.

        Returns:
            ``(dataset_name, index, preview)`` tuples in scan order, where
            *preview* is the question text cut to 100 characters plus
            ``"..."`` when it is longer. Empty for an unknown dataset name or
            a non-positive *max_results*.
        """
        results = []
        if max_results <= 0:
            return results

        if dataset == 'all':
            datasets_to_search = list(self.data)
        elif dataset in self.data:
            datasets_to_search = [dataset]
        else:
            # Unknown dataset name: nothing to scan (previously a KeyError).
            return results

        query_lower = query.lower()

        for ds in datasets_to_search:
            for idx, q in enumerate(self.data[ds]):
                # Records may spell the field 'question' or 'Question'.
                question_text = q.get('question') or q.get('Question') or ''
                if not isinstance(question_text, str):
                    # A null or non-text question cannot match a keyword.
                    continue

                if query_lower in question_text.lower():
                    preview = question_text[:100] + "..." if len(question_text) > 100 else question_text
                    results.append((ds, idx, preview))

                    if len(results) >= max_results:
                        return results

        return results
88
 
89
+ # Initialize database
90
+ print("πŸš€ Initializing MedQA Explorer...")
91
+ db = MedQADatabase()
92
 
93
+ # ============================================================================
94
+ # GRADIO INTERFACE FUNCTIONS
95
+ # ============================================================================
 
96
 
97
def format_question_display(question_data: Dict, dataset: str) -> str:
    """Render one question record as HTML, picking the layout from *dataset*.

    A missing or empty record yields a short not-found message instead.
    """
    if not question_data:
        return "❌ Question not found"

    # Med-Gemini records have their own layout; every other dataset name
    # is rendered with the MedQA layout.
    renderer = format_medgemini_question if dataset == 'medgemini' else format_medqa_question
    return renderer(question_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
def format_medgemini_question(q: Dict) -> str:
    """Format a Med-Gemini question record as an HTML fragment.

    Reads ``question``, ``options``, ``answer_idx`` and ``explanation`` (or
    ``Explanation``) from *q*. Options are looked up for labels A-E under the
    keys ``opa`` .. ``ope``, falling back to the bare label (``'A'`` ..
    ``'E'``). All record text is HTML-escaped so that characters such as
    ``<`` or ``&`` in the data are shown literally instead of corrupting
    the markup.
    """
    # Local import: the name is not provided by the file's import block.
    from html import escape

    options = q.get('options', {})
    if not isinstance(options, dict):
        # Anything that is not a mapping has no labelled options to show.
        options = {}
    correct_answer = q.get('answer_idx', 'N/A')

    parts = [f"""
    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
    <h2 style="color: white; margin: 0;">πŸ”¬ Med-Gemini Question</h2>
    </div>

    <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
    <h3>πŸ“‹ Question</h3>
    <p style="font-size: 16px; line-height: 1.6;">{escape(str(q.get('question', 'N/A')))}</p>
    </div>

    <div style="background: #fff; padding: 20px; border-radius: 8px; margin-bottom: 20px; border: 2px solid #e0e0e0;">
    <h3>πŸ”€ Answer Options</h3>
    """]

    # Display options
    for label in ('A', 'B', 'C', 'D', 'E'):
        prefixed_key = f'op{label.lower()}'
        if prefixed_key in options:
            option_text = options[prefixed_key]
        elif label in options:
            option_text = options[label]
        else:
            continue

        is_correct = (label == correct_answer)
        color = '#d4edda' if is_correct else '#fff'
        icon = 'βœ…' if is_correct else 'β­•'

        parts.append(f"""
    <div style="background: {color}; padding: 12px; margin: 8px 0; border-radius: 5px; border: 1px solid #ccc;">
    {icon} <strong>{label}.</strong> {escape(str(option_text))}
    </div>
    """)

    parts.append("</div>")

    # Show correct answer
    parts.append(f"""
    <div style="background: #d4edda; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #28a745;">
    <h3 style="margin-top: 0;">βœ… Correct Answer</h3>
    <p style="font-size: 18px; font-weight: bold; margin: 0;">{escape(str(correct_answer))}</p>
    </div>
    """)

    # Show explanation if available
    explanation = q.get('explanation', q.get('Explanation', ''))
    if explanation:
        parts.append(f"""
    <div style="background: #e7f3ff; padding: 20px; border-radius: 8px; border-left: 4px solid #2196F3;">
    <h3 style="margin-top: 0;">πŸ’‘ Explanation</h3>
    <p style="line-height: 1.6;">{escape(str(explanation))}</p>
    </div>
    """)

    return "".join(parts)
164
+
165
def format_medqa_question(q: Dict) -> str:
    """Format a MedQA original question record as an HTML fragment.

    Reads ``question``, ``options``, ``answer_idx`` and ``metamap_phrases``
    from *q*. Every entry of ``options`` is shown; a key starting with
    ``op`` is turned into its upper-cased remainder to form the label. All
    record text is HTML-escaped so that characters such as ``<`` or ``&``
    in the data are shown literally instead of corrupting the markup.
    """
    # Local import: the name is not provided by the file's import block.
    from html import escape

    options = q.get('options', {})
    if not isinstance(options, dict):
        # Anything that is not a mapping has no labelled options to show.
        options = {}
    correct_answer = q.get('answer_idx', 'N/A')

    parts = [f"""
    <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
    <h2 style="color: white; margin: 0;">πŸ“š MedQA USMLE Question</h2>
    </div>

    <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
    <h3>πŸ“‹ Question</h3>
    <p style="font-size: 16px; line-height: 1.6;">{escape(str(q.get('question', 'N/A')))}</p>
    </div>

    <div style="background: #fff; padding: 20px; border-radius: 8px; margin-bottom: 20px; border: 2px solid #e0e0e0;">
    <h3>πŸ”€ Answer Options</h3>
    """]

    # Display options
    for key, value in options.items():
        key = str(key)
        label = key.replace('op', '').upper() if key.startswith('op') else key
        is_correct = (label == correct_answer)
        color = '#d4edda' if is_correct else '#fff'
        icon = 'βœ…' if is_correct else 'β­•'

        parts.append(f"""
    <div style="background: {color}; padding: 12px; margin: 8px 0; border-radius: 5px; border: 1px solid #ccc;">
    {icon} <strong>{escape(label)}.</strong> {escape(str(value))}
    </div>
    """)

    parts.append("</div>")

    # Show correct answer
    parts.append(f"""
    <div style="background: #d4edda; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #28a745;">
    <h3 style="margin-top: 0;">βœ… Correct Answer</h3>
    <p style="font-size: 18px; font-weight: bold; margin: 0;">{escape(str(correct_answer))}</p>
    </div>
    """)

    # Show metamap if available
    metamap = q.get('metamap_phrases')
    if metamap:
        if isinstance(metamap, str):
            # A bare string is shown as-is rather than joined character by character.
            concepts = metamap
        else:
            concepts = ', '.join(str(phrase) for phrase in metamap)
        parts.append(f"""
    <div style="background: #fff3cd; padding: 15px; border-radius: 8px; border-left: 4px solid #ffc107;">
    <h3 style="margin-top: 0;">πŸ₯ Medical Concepts (MetaMap)</h3>
    <p style="line-height: 1.6;">{escape(concepts)}</p>
    </div>
    """)

    return "".join(parts)
218
+
219
def browse_questions(dataset: str, index: int) -> Tuple[str, str]:
    """Return the rendered question at *index* of *dataset* plus a status line.

    The index is clamped into the dataset's valid range first. For an empty
    or unknown dataset a message pair is returned instead.
    """
    question_count = len(db.data.get(dataset, []))
    if question_count == 0:
        return "❌ No questions in this dataset", f"Dataset: {dataset} (empty)"

    # Clamp index to valid range
    index = max(0, min(index, question_count - 1))

    rendered = format_question_display(db.get_question(dataset, index), dataset)
    status = f"πŸ“Š Question {index + 1} of {question_count} | Dataset: {dataset}"
    return rendered, status
234
 
235
def random_question(dataset: str) -> Tuple[str, str, int]:
    """Pick a uniformly random question from *dataset*.

    Returns the rendered HTML, a status line and the chosen zero-based index.
    For an empty or unknown dataset a message pair and index 0 are returned.
    """
    question_count = len(db.data.get(dataset, []))
    if question_count == 0:
        return "❌ No questions in this dataset", f"Dataset: {dataset} (empty)", 0

    # randrange(n) draws from 0 .. n-1, the same range as randint(0, n - 1).
    index = random.randrange(question_count)
    rendered = format_question_display(db.get_question(dataset, index), dataset)
    status = f"🎲 Random Question {index + 1} of {question_count} | Dataset: {dataset}"
    return rendered, status, index
248
 
249
def search_interface(query: str, dataset: str) -> str:
    """Run a keyword search and render the hits as an HTML fragment.

    A blank query returns a hint and no search is run. At most the first 20
    hits are listed, followed by a count of the remaining ones. The query,
    the dataset name and the question previews are HTML-escaped because the
    result is shown in an HTML component and the query is typed by the user.
    """
    # Local import: the name is not provided by the file's import block.
    from html import escape

    if not query.strip():
        return "πŸ’‘ Enter a search query to find questions"

    results = db.search_questions(query, dataset)

    safe_query = escape(str(query))
    safe_dataset = escape(str(dataset))

    if not results:
        return f"❌ No results found for '{safe_query}' in {safe_dataset}"

    parts = [f"""
    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;">
    <h2 style="color: white; margin: 0;">πŸ” Search Results: "{safe_query}"</h2>
    <p style="color: white; margin: 5px 0 0 0;">Found {len(results)} results in {safe_dataset}</p>
    </div>
    """]

    for ds, idx, preview in results[:20]:  # Show top 20
        dataset_name = ds.replace('_', ' ').title()
        parts.append(f"""
    <div style="background: #fff; padding: 15px; margin: 10px 0; border-radius: 8px; border-left: 4px solid #667eea;">
    <p style="margin: 0; color: #666; font-size: 12px;"><strong>{escape(dataset_name)}</strong> - Question #{idx + 1}</p>
    <p style="margin: 5px 0 0 0;">{escape(str(preview))}</p>
    </div>
    """)

    if len(results) > 20:
        parts.append(f"<p>... and {len(results) - 20} more results</p>")

    return "".join(parts)
279
 
280
+ # ============================================================================
281
+ # GRADIO APP
282
+ # ============================================================================
 
 
 
 
 
 
283
 
284
# Gradio layout: a statistics accordion, a "Browse" tab (dataset dropdown,
# index slider, previous/random/next buttons) and a "Search" tab.
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Database Explorer") as app:

    gr.Markdown("""
    # πŸ₯ MedQA Database Explorer

    Explore medical question-answering databases including **Med-Gemini** and **MedQA USMLE**.
    """)

    # Statistics
    with gr.Accordion("πŸ“Š Database Statistics", open=False):
        gr.Markdown(db.get_stats())

    # Main interface
    with gr.Tabs():

        # Browse Tab
        with gr.Tab("πŸ“– Browse Questions"):
            with gr.Row():
                with gr.Column(scale=1):
                    dataset_dropdown = gr.Dropdown(
                        choices=['medgemini', 'medqa_train', 'medqa_dev', 'medqa_test'],
                        value='medgemini',
                        label="Select Database"
                    )

                    question_slider = gr.Slider(
                        minimum=0,
                        # Never below the minimum, even when the dataset is empty.
                        maximum=max(0, len(db.data['medgemini']) - 1),
                        value=0,
                        step=1,
                        label="Question Number"
                    )

                    with gr.Row():
                        prev_btn = gr.Button("⬅️ Previous", size="sm")
                        random_btn = gr.Button("🎲 Random", size="sm", variant="primary")
                        next_btn = gr.Button("Next ➑️", size="sm")

                    info_text = gr.Textbox(label="Info", interactive=False)

                with gr.Column(scale=2):
                    question_display = gr.HTML()

            def update_slider(dataset):
                """Return a slider update sized for *dataset* and reset to 0."""
                max_val = max(0, len(db.data.get(dataset, [])) - 1)
                return gr.Slider(maximum=max_val, value=0)

            def show_question(dataset, index):
                """Render the question the slider currently points at."""
                return browse_questions(dataset, int(index))

            def change_dataset(dataset):
                """Resize the slider and show the first question of *dataset*.

                Done in one handler so the question is never rendered from a
                slider value that belongs to the previously selected dataset.
                """
                html, info = browse_questions(dataset, 0)
                return update_slider(dataset), html, info

            dataset_dropdown.change(
                fn=change_dataset,
                inputs=[dataset_dropdown],
                outputs=[question_slider, question_display, info_text]
            )

            question_slider.change(
                fn=show_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text]
            )

            # Navigation buttons
            def prev_question(dataset, index):
                """Step one question back, stopping at the first one."""
                new_index = max(0, int(index) - 1)
                html, info = browse_questions(dataset, new_index)
                return html, info, new_index

            def next_question(dataset, index):
                """Step one question forward, stopping at the last one."""
                max_idx = max(0, len(db.data.get(dataset, [])) - 1)
                new_index = min(max_idx, int(index) + 1)
                html, info = browse_questions(dataset, new_index)
                return html, info, new_index

            prev_btn.click(
                fn=prev_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text, question_slider]
            )

            next_btn.click(
                fn=next_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text, question_slider]
            )

            random_btn.click(
                fn=random_question,
                inputs=[dataset_dropdown],
                outputs=[question_display, info_text, question_slider]
            )

            # Load first question on start
            app.load(
                fn=show_question,
                inputs=[dataset_dropdown, question_slider],
                outputs=[question_display, info_text]
            )

        # Search Tab
        with gr.Tab("πŸ” Search"):
            with gr.Row():
                search_query = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter keywords (e.g., 'diabetes', 'heart failure', 'treatment')...",
                    scale=3
                )
                search_dataset = gr.Dropdown(
                    choices=['all', 'medgemini', 'medqa_train', 'medqa_dev', 'medqa_test'],
                    value='all',
                    label="Search In",
                    scale=1
                )

            search_btn = gr.Button("πŸ” Search", variant="primary")
            search_results = gr.HTML()

            search_btn.click(
                fn=search_interface,
                inputs=[search_query, search_dataset],
                outputs=[search_results]
            )

            # Also search on Enter key
            search_query.submit(
                fn=search_interface,
                inputs=[search_query, search_dataset],
                outputs=[search_results]
            )

    gr.Markdown("""
    ---
    ### πŸ“š About the Databases

    **Med-Gemini**: Expert-relabeled medical questions with detailed explanations from Google's Med-Gemini project.

    **MedQA**: Original USMLE-style medical questions from the MedQA dataset.

    ### πŸ”— Sources
    - [Med-Gemini Paper](https://arxiv.org/abs/2404.18416)
    - [MedQA Dataset](https://github.com/jind11/MedQA)
    """)
434
 
 
435
  if __name__ == "__main__":
436
+ app.launch()