import gradio as gr import json import zipfile from pathlib import Path import pandas as pd from typing import Dict, List, Tuple import random class MedQADatabase: """Handler for MedQA and Med-Gemini databases""" def __init__(self, zip_path="medqa_databases.zip"): self.data = { 'medgemini': [], 'medqa_train': [], 'medqa_dev': [], 'medqa_test': [] } self.load_databases(zip_path) def load_databases(self, zip_path): """Load all databases from the ZIP file""" print("📦 Loading databases from ZIP...") try: with zipfile.ZipFile(zip_path, 'r') as zip_ref: # Extract to temporary directory zip_ref.extractall('temp_data') # Load Med-Gemini medgemini_path = Path('temp_data/medqa_databases/med_gemini/medqa_relabelling.json') if medgemini_path.exists(): with open(medgemini_path, 'r', encoding='utf-8') as f: self.data['medgemini'] = json.load(f) print(f"✅ Loaded {len(self.data['medgemini'])} Med-Gemini questions") # Load MedQA splits medqa_base = Path('temp_data/medqa_databases/medqa_original') for split in ['train', 'dev', 'test']: split_path = medqa_base / f"{split}.json" if split_path.exists(): with open(split_path, 'r', encoding='utf-8') as f: self.data[f'medqa_{split}'] = json.load(f) print(f"✅ Loaded {len(self.data[f'medqa_{split}'])} MedQA {split} questions") except Exception as e: print(f"❌ Error loading databases: {e}") raise def get_stats(self) -> str: """Get database statistics""" stats = "## 📊 Database Statistics\n\n" stats += f"**Med-Gemini**: {len(self.data['medgemini']):,} questions\n\n" stats += f"**MedQA Original**:\n" stats += f"- Training: {len(self.data['medqa_train']):,} questions\n" stats += f"- Development: {len(self.data['medqa_dev']):,} questions\n" stats += f"- Test: {len(self.data['medqa_test']):,} questions\n" stats += f"- **Total**: {sum(len(self.data[f'medqa_{s}']) for s in ['train', 'dev', 'test']):,} questions\n\n" stats += f"**Grand Total**: {sum(len(v) for v in self.data.values()):,} questions" return stats def get_question(self, dataset: str, index: int) -> Dict: """Get a specific question from a dataset""" try: return self.data[dataset][index] except (KeyError, IndexError): return None def search_questions(self, query: str, dataset: str = 'all', max_results: int = 50) -> List[Tuple[str, int, str]]: """Search questions by keyword""" results = [] query_lower = query.lower() datasets_to_search = list(self.data.keys()) if dataset == 'all' else [dataset] for ds in datasets_to_search: for idx, q in enumerate(self.data[ds]): # Search in question text question_text = q.get('question', q.get('Question', '')) if query_lower in question_text.lower(): preview = question_text[:100] + "..." if len(question_text) > 100 else question_text results.append((ds, idx, preview)) if len(results) >= max_results: return results return results # Initialize database print("🚀 Initializing MedQA Explorer...") db = MedQADatabase() # ============================================================================ # GRADIO INTERFACE FUNCTIONS # ============================================================================ def format_question_display(question_data: Dict, dataset: str) -> str: """Format question data for display""" if not question_data: return "❌ Question not found" # Handle different data formats if dataset == 'medgemini': return format_medgemini_question(question_data) else: return format_medqa_question(question_data) def format_medgemini_question(q: Dict) -> str: """Format Med-Gemini question""" html = f"""
{q.get('question', 'N/A')}
{correct_answer}
{explanation}
{q.get('question', 'N/A')}
{correct_answer}
{', '.join(metamap)}
Found {len(results)} results in {dataset}
{dataset_name} - Question #{idx + 1}
{preview}
... and {len(results) - 20} more results
" return html # ============================================================================ # GRADIO APP # ============================================================================ with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Database Explorer") as app: gr.Markdown(""" # 🏥 MedQA Database Explorer Explore medical question-answering databases including **Med-Gemini** and **MedQA USMLE**. """) # Statistics with gr.Accordion("📊 Database Statistics", open=False): gr.Markdown(db.get_stats()) # Main interface with gr.Tabs(): # Browse Tab with gr.Tab("📖 Browse Questions"): with gr.Row(): with gr.Column(scale=1): dataset_dropdown = gr.Dropdown( choices=['medgemini', 'medqa_train', 'medqa_dev', 'medqa_test'], value='medgemini', label="Select Database" ) question_slider = gr.Slider( minimum=0, maximum=len(db.data['medgemini']) - 1, value=0, step=1, label="Question Number" ) with gr.Row(): prev_btn = gr.Button("⬅️ Previous", size="sm") random_btn = gr.Button("🎲 Random", size="sm", variant="primary") next_btn = gr.Button("Next ➡️", size="sm") info_text = gr.Textbox(label="Info", interactive=False) with gr.Column(scale=2): question_display = gr.HTML() # Update slider max when dataset changes def update_slider(dataset): max_val = len(db.data.get(dataset, [])) - 1 return gr.Slider(maximum=max_val, value=0) dataset_dropdown.change( fn=update_slider, inputs=[dataset_dropdown], outputs=[question_slider] ) # Browse functions def show_question(dataset, index): return browse_questions(dataset, int(index)) question_slider.change( fn=show_question, inputs=[dataset_dropdown, question_slider], outputs=[question_display, info_text] ) dataset_dropdown.change( fn=show_question, inputs=[dataset_dropdown, question_slider], outputs=[question_display, info_text] ) # Navigation buttons def prev_question(dataset, index): new_index = max(0, int(index) - 1) html, info = browse_questions(dataset, new_index) return html, info, new_index def next_question(dataset, index): max_idx = len(db.data.get(dataset, [])) - 1 new_index = min(max_idx, int(index) + 1) html, info = browse_questions(dataset, new_index) return html, info, new_index prev_btn.click( fn=prev_question, inputs=[dataset_dropdown, question_slider], outputs=[question_display, info_text, question_slider] ) next_btn.click( fn=next_question, inputs=[dataset_dropdown, question_slider], outputs=[question_display, info_text, question_slider] ) random_btn.click( fn=random_question, inputs=[dataset_dropdown], outputs=[question_display, info_text, question_slider] ) # Load first question on start app.load( fn=show_question, inputs=[dataset_dropdown, question_slider], outputs=[question_display, info_text] ) # Search Tab with gr.Tab("🔍 Search"): with gr.Row(): search_query = gr.Textbox( label="Search Query", placeholder="Enter keywords (e.g., 'diabetes', 'heart failure', 'treatment')...", scale=3 ) search_dataset = gr.Dropdown( choices=['all', 'medgemini', 'medqa_train', 'medqa_dev', 'medqa_test'], value='all', label="Search In", scale=1 ) search_btn = gr.Button("🔍 Search", variant="primary") search_results = gr.HTML() search_btn.click( fn=search_interface, inputs=[search_query, search_dataset], outputs=[search_results] ) # Also search on Enter key search_query.submit( fn=search_interface, inputs=[search_query, search_dataset], outputs=[search_results] ) gr.Markdown(""" --- ### 📚 About the Databases **Med-Gemini**: Expert-relabeled medical questions with detailed explanations from Google's Med-Gemini project. **MedQA**: Original USMLE-style medical questions from the MedQA dataset. ### 🔗 Sources - [Med-Gemini Paper](https://arxiv.org/abs/2404.18416) - [MedQA Dataset](https://github.com/jind11/MedQA) """) if __name__ == "__main__": app.launch()