"""Gradio app for browsing and searching the MedQA and Med-Gemini question sets.

Importing this module loads the question databases from a ZIP archive and
builds the Gradio ``app``; running it as a script launches the app.
"""

import json
import random
import zipfile
from html import escape
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import gradio as gr
import pandas as pd  # not used by the code below; kept from the original import list

# Keys of ``MedQADatabase.data``, in the order shown in the dropdowns.
DATASET_NAMES = ['medgemini', 'medqa_train', 'medqa_dev', 'medqa_test']
MEDQA_SPLITS = ['train', 'dev', 'test']

# Directory the ZIP archive is extracted into.
EXTRACT_DIR = 'temp_data'

# How many search hits are rendered before the "... and N more" line.
MAX_DISPLAYED_RESULTS = 20

HEADER_MARKDOWN = """
# 🏥 MedQA Database Explorer

Explore medical question-answering databases including **Med-Gemini** and **MedQA USMLE**.
"""

FOOTER_MARKDOWN = """
---

### 📚 About the Databases

**Med-Gemini**: Expert-relabeled medical questions with detailed explanations from Google's Med-Gemini project.

**MedQA**: Original USMLE-style medical questions from the MedQA dataset.

### 🔗 Sources

- [Med-Gemini Paper](https://arxiv.org/abs/2404.18416)
- [MedQA Dataset](https://github.com/jind11/MedQA)
"""


def _load_json(path: Path):
    """Return the parsed contents of the UTF-8 JSON file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


class MedQADatabase:
    """Handler for MedQA and Med-Gemini databases.

    ``data`` maps each dataset name in ``DATASET_NAMES`` to whatever the
    corresponding JSON file contained (an empty list if the file was absent).
    """

    def __init__(self, zip_path="medqa_databases.zip"):
        self.data = {name: [] for name in DATASET_NAMES}
        self.load_databases(zip_path)

    def load_databases(self, zip_path):
        """Extract *zip_path* into ``EXTRACT_DIR`` and load every JSON file found.

        Files that are missing from the archive are skipped, leaving the
        matching dataset empty. Any error is reported on stdout and re-raised.
        """
        print("📦 Loading databases from ZIP...")
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(EXTRACT_DIR)

            base = Path(EXTRACT_DIR) / 'medqa_databases'

            # Load Med-Gemini
            medgemini_path = base / 'med_gemini' / 'medqa_relabelling.json'
            if medgemini_path.exists():
                self.data['medgemini'] = _load_json(medgemini_path)
                print(f"✅ Loaded {len(self.data['medgemini'])} Med-Gemini questions")

            # Load MedQA splits
            for split in MEDQA_SPLITS:
                split_path = base / 'medqa_original' / f"{split}.json"
                if split_path.exists():
                    key = f'medqa_{split}'
                    self.data[key] = _load_json(split_path)
                    print(f"✅ Loaded {len(self.data[key])} MedQA {split} questions")
        except Exception as e:
            print(f"❌ Error loading databases: {e}")
            raise

    def get_stats(self) -> str:
        """Return a Markdown summary of how many questions each dataset holds."""
        counts = {name: len(items) for name, items in self.data.items()}
        medqa_total = sum(counts[f'medqa_{s}'] for s in MEDQA_SPLITS)
        grand_total = sum(counts.values())
        return "".join([
            "## 📊 Database Statistics\n\n",
            f"**Med-Gemini**: {counts['medgemini']:,} questions\n\n",
            "**MedQA Original**:\n",
            f"- Training: {counts['medqa_train']:,} questions\n",
            f"- Development: {counts['medqa_dev']:,} questions\n",
            f"- Test: {counts['medqa_test']:,} questions\n",
            f"- **Total**: {medqa_total:,} questions\n\n",
            f"**Grand Total**: {grand_total:,} questions",
        ])

    def get_question(self, dataset: str, index: int) -> Optional[Dict]:
        """Return entry *index* of *dataset*, or ``None`` if either is invalid."""
        try:
            return self.data[dataset][index]
        except (KeyError, IndexError):
            return None

    def search_questions(self, query: str, dataset: str = 'all',
                         max_results: int = 50) -> List[Tuple[str, int, str]]:
        """Search question text for *query* (case-insensitive substring match).

        Args:
            query: Text to look for.
            dataset: A dataset name, or ``'all'`` to search every dataset.
                An unknown name yields no results.
            max_results: The search stops as soon as this many hits are found.

        Returns:
            ``(dataset_name, index, preview)`` tuples, where *preview* is the
            question text truncated to 100 characters plus ``"..."``.
        """
        results = []
        query_lower = query.lower()
        datasets_to_search = list(self.data) if dataset == 'all' else [dataset]

        for ds in datasets_to_search:
            for idx, q in enumerate(self.data.get(ds, [])):
                if not isinstance(q, dict):
                    continue
                # Entries may use either capitalisation; a missing or null
                # value is treated as empty text rather than crashing.
                question_text = str(q.get('question') or q.get('Question') or '')
                if query_lower not in question_text.lower():
                    continue
                if len(question_text) > 100:
                    preview = question_text[:100] + "..."
                else:
                    preview = question_text
                results.append((ds, idx, preview))
                if len(results) >= max_results:
                    return results
        return results


# Initialize database
print("🚀 Initializing MedQA Explorer...")
db = MedQADatabase()


# ============================================================================
# GRADIO INTERFACE FUNCTIONS
# ============================================================================

def _text(value) -> str:
    """Return *value* as HTML-escaped text so dataset content cannot inject markup."""
    return escape(str(value))


def _section(title: str, body_html: str) -> str:
    """Return an HTML section: a heading followed by already-rendered *body_html*."""
    return f"<h3>{title}</h3>\n<div>{body_html}</div>\n"


def _option_row(label, text, is_correct: bool) -> str:
    """Return one answer-option row; the correct option is highlighted."""
    color = '#d4edda' if is_correct else '#fff'
    icon = '✅' if is_correct else '⭕'
    return (
        f'<div style="background-color: {color}; padding: 8px; '
        f'margin: 4px 0; border-radius: 4px;">'
        f'{icon} <strong>{_text(label)}.</strong> {_text(text)}</div>\n'
    )


def _render_question(title: str, q: Dict, option_rows: Iterable[str],
                     extra_sections: Iterable[str] = ()) -> str:
    """Assemble the HTML card shared by both question formats."""
    parts = [
        f"<h2>{title}</h2>\n",
        _section("📋 Question", _text(q.get('question', 'N/A'))),
        _section("🔤 Answer Options", "".join(option_rows)),
        _section("✅ Correct Answer", _text(q.get('answer_idx', 'N/A'))),
    ]
    parts.extend(extra_sections)
    return "<div>\n" + "".join(parts) + "</div>\n"


def format_question_display(question_data: Dict, dataset: str) -> str:
    """Format question data for display, choosing the renderer by *dataset*."""
    if not question_data:
        return "❌ Question not found"

    # Handle different data formats
    if dataset == 'medgemini':
        return format_medgemini_question(question_data)
    return format_medqa_question(question_data)


def format_medgemini_question(q: Dict) -> str:
    """Format a Med-Gemini question as HTML.

    Options are looked up for labels A-E under the key ``op<letter>``
    (``opa`` ... ``ope``), falling back to the bare letter. The option whose
    label equals ``answer_idx`` is highlighted, and the explanation section is
    added only when the entry has one.
    """
    options = q.get('options', {})
    if not isinstance(options, dict):
        options = {}
    correct_answer = q.get('answer_idx', 'N/A')

    rows = []
    for label in ['A', 'B', 'C', 'D', 'E']:
        for option_key in (f'op{label.lower()}', label):
            if option_key in options:
                rows.append(_option_row(label, options[option_key],
                                        label == correct_answer))
                break

    extras = []
    explanation = q.get('explanation', q.get('Explanation', ''))
    if explanation:
        extras.append(_section("💡 Explanation", _text(explanation)))

    return _render_question("🔬 Med-Gemini Question", q, rows, extras)


def format_medqa_question(q: Dict) -> str:
    """Format a MedQA original question as HTML.

    Every entry of ``options`` is shown; a key starting with ``op`` is
    displayed as the upper-cased remainder, any other key is displayed as is.
    The MetaMap section is added only when ``metamap_phrases`` is present.
    """
    options = q.get('options', {})
    if not isinstance(options, dict):
        options = {}
    correct_answer = q.get('answer_idx', 'N/A')

    rows = []
    for key, value in options.items():
        key = str(key)
        label = key[2:].upper() if key.startswith('op') else key
        rows.append(_option_row(label, value, label == correct_answer))

    extras = []
    metamap = q.get('metamap_phrases')
    if metamap:
        concepts = ', '.join(str(phrase) for phrase in metamap)
        extras.append(_section("🏥 Medical Concepts (MetaMap)", _text(concepts)))

    return _render_question("📚 MedQA USMLE Question", q, rows, extras)


def browse_questions(dataset: str, index: int) -> Tuple[str, str]:
    """Return ``(html, info)`` for question *index* of *dataset*.

    *index* is clamped to the valid range; an empty or unknown dataset yields
    an error message instead of a question.
    """
    total = len(db.data.get(dataset, []))
    if total == 0:
        return "❌ No questions in this dataset", f"Dataset: {dataset} (empty)"

    # Clamp index to valid range
    index = max(0, min(index, total - 1))

    question = db.get_question(dataset, index)
    html = format_question_display(question, dataset)
    info = f"📊 Question {index + 1} of {total} | Dataset: {dataset}"
    return html, info


def random_question(dataset: str) -> Tuple[str, str, int]:
    """Return ``(html, info, index)`` for a randomly chosen question of *dataset*."""
    total = len(db.data.get(dataset, []))
    if total == 0:
        return "❌ No questions in this dataset", f"Dataset: {dataset} (empty)", 0

    index = random.randrange(total)
    question = db.get_question(dataset, index)
    html = format_question_display(question, dataset)
    info = f"🎲 Random Question {index + 1} of {total} | Dataset: {dataset}"
    return html, info, index


def search_interface(query: str, dataset: str) -> str:
    """Run a search and return the hits as HTML.

    At most ``MAX_DISPLAYED_RESULTS`` hits are listed; any remainder is
    summarised in a final line. The query is user input, so it is escaped
    before being placed in the markup.
    """
    if not query or not query.strip():
        return "💡 Enter a search query to find questions"

    results = db.search_questions(query, dataset)
    if not results:
        return f"❌ No results found for '{_text(query)}' in {_text(dataset)}"

    parts = [
        "<div>\n",
        f"<h2>🔍 Search Results: \"{_text(query)}\"</h2>\n",
        f"<p>Found {len(results)} results in {_text(dataset)}</p>\n",
    ]
    for ds, idx, preview in results[:MAX_DISPLAYED_RESULTS]:
        dataset_name = ds.replace('_', ' ').title()
        parts.append(
            '<div style="padding: 8px; margin: 6px 0; border: 1px solid #ddd; '
            'border-radius: 4px;">\n'
            f"<strong>{_text(dataset_name)} - Question #{idx + 1}</strong>\n"
            f"<p>{_text(preview)}</p>\n"
            "</div>\n"
        )

    hidden = len(results) - MAX_DISPLAYED_RESULTS
    if hidden > 0:
        parts.append(f"<p>... and {hidden} more results</p>\n")

    parts.append("</div>\n")
    return "".join(parts)


def _last_index(dataset: str) -> int:
    """Return the highest valid index of *dataset*, or 0 when it is empty.

    Never negative, so the value is always usable as the slider maximum
    (the slider minimum is 0).
    """
    return max(0, len(db.data.get(dataset, [])) - 1)


def update_slider(dataset):
    """Reset the slider to 0 and size it to the newly selected dataset."""
    return gr.Slider(maximum=_last_index(dataset), value=0)


def show_question(dataset, index):
    """Render the question at the slider position (the slider may deliver a float)."""
    return browse_questions(dataset, int(index))


def prev_question(dataset, index):
    """Step one question back, stopping at the first question."""
    new_index = max(0, int(index) - 1)
    html, info = browse_questions(dataset, new_index)
    return html, info, new_index


def next_question(dataset, index):
    """Step one question forward, stopping at the last question."""
    new_index = max(0, min(_last_index(dataset), int(index) + 1))
    html, info = browse_questions(dataset, new_index)
    return html, info, new_index


# ============================================================================
# GRADIO APP
# ============================================================================

with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Database Explorer") as app:
    gr.Markdown(HEADER_MARKDOWN)

    # Statistics
    with gr.Accordion("📊 Database Statistics", open=False):
        gr.Markdown(db.get_stats())

    # Main interface
    with gr.Tabs():
        # Browse Tab
        with gr.Tab("📖 Browse Questions"):
            with gr.Row():
                with gr.Column(scale=1):
                    dataset_dropdown = gr.Dropdown(
                        choices=list(DATASET_NAMES),
                        value='medgemini',
                        label="Select Database",
                    )
                    question_slider = gr.Slider(
                        minimum=0,
                        maximum=_last_index('medgemini'),
                        value=0,
                        step=1,
                        label="Question Number",
                    )
                    with gr.Row():
                        prev_btn = gr.Button("⬅️ Previous", size="sm")
                        random_btn = gr.Button("🎲 Random", size="sm", variant="primary")
                        next_btn = gr.Button("Next ➡️", size="sm")
                    info_text = gr.Textbox(label="Info", interactive=False)

                with gr.Column(scale=2):
                    question_display = gr.HTML()

            browse_inputs = [dataset_dropdown, question_slider]
            browse_outputs = [question_display, info_text]
            nav_outputs = [question_display, info_text, question_slider]

            # Update slider max when dataset changes
            dataset_dropdown.change(
                fn=update_slider,
                inputs=[dataset_dropdown],
                outputs=[question_slider],
            )

            # Browse functions
            question_slider.change(
                fn=show_question, inputs=browse_inputs, outputs=browse_outputs
            )
            dataset_dropdown.change(
                fn=show_question, inputs=browse_inputs, outputs=browse_outputs
            )

            # Navigation buttons
            prev_btn.click(fn=prev_question, inputs=browse_inputs, outputs=nav_outputs)
            next_btn.click(fn=next_question, inputs=browse_inputs, outputs=nav_outputs)
            random_btn.click(
                fn=random_question, inputs=[dataset_dropdown], outputs=nav_outputs
            )

            # Load first question on start
            app.load(fn=show_question, inputs=browse_inputs, outputs=browse_outputs)

        # Search Tab
        with gr.Tab("🔍 Search"):
            with gr.Row():
                search_query = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter keywords (e.g., 'diabetes', 'heart failure', 'treatment')...",
                    scale=3,
                )
                search_dataset = gr.Dropdown(
                    choices=['all'] + DATASET_NAMES,
                    value='all',
                    label="Search In",
                    scale=1,
                )
            search_btn = gr.Button("🔍 Search", variant="primary")
            search_results = gr.HTML()

            search_inputs = [search_query, search_dataset]
            search_btn.click(
                fn=search_interface, inputs=search_inputs, outputs=[search_results]
            )
            # Also search on Enter key
            search_query.submit(
                fn=search_interface, inputs=search_inputs, outputs=[search_results]
            )

    gr.Markdown(FOOTER_MARKDOWN)


if __name__ == "__main__":
    app.launch()