import os
os.environ['ANONYMIZED_TELEMETRY'] = 'False'
import zipfile
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
import re
import anthropic  # You'll need: pip install anthropic
# OR if using OpenAI: import openai

# Extract and load database
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    print("📦 Extracting database...")
    with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
        z.extractall(".")
    print("✅ Database extracted")

print("📂 Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"✅ Loaded {collection.count()} questions")

print("🧠 Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("✅ Model ready")

# Initialize AI client (choose one)
# Option 1: Claude (expects ANTHROPIC_API_KEY in the environment, e.g. set as a Space secret)
claude_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
# Option 2: OpenAI (uncomment if using)
# openai.api_key = os.environ.get("OPENAI_API_KEY")

# ============================================================================
# Deduplication function (same as before)
# ============================================================================
def deduplicate_results(results, target_count):
    """Filter near-duplicate hits, keeping at most target_count distinct results."""
    if not results['documents'][0]:
        return results

    documents = results['documents'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    selected_indices = []
    for i in range(len(documents)):
        is_duplicate = False
        current_answer = metadatas[i].get('answer', '')
        for j in selected_indices:
            selected_answer = metadatas[j].get('answer', '')
            dist_diff = abs(distances[i] - distances[j])
            # Nearly identical distance to the query: treat as a duplicate.
            if dist_diff < 0.08:
                is_duplicate = True
                break
            # Same answer and similar distance: also treat as a duplicate.
            if current_answer == selected_answer and dist_diff < 0.15:
                is_duplicate = True
                break
        if not is_duplicate:
            selected_indices.append(i)
        if len(selected_indices) >= target_count:
            break

    return {
        'documents': [[documents[i] for i in selected_indices]],
        'metadatas': [[metadatas[i] for i in selected_indices]],
        'distances': [[distances[i] for i in selected_indices]],
        'ids': [[results['ids'][0][i] for i in selected_indices]] if 'ids' in results else None
    }

# ============================================================================
# Search function (same as before)
# ============================================================================
def search(query, num_results=3, source_filter=None):
    """Embed the query, run a ChromaDB similarity search, and deduplicate the hits."""
    emb = model.encode(query).tolist()

    where_clause = None
    if source_filter and source_filter != "all":
        where_clause = {"source": source_filter}

    # Over-fetch so deduplication still leaves enough distinct results.
    fetch_count = min(num_results * 4, 50)
    results = collection.query(
        query_embeddings=[emb],
        n_results=fetch_count,
        where=where_clause
    )
    return deduplicate_results(results, num_results)
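
# Illustrative call (the query text is hypothetical; valid source_filter values
# depend on the "source" metadata used when the database was built):
#   search("45-year-old man with crushing substernal chest pain", num_results=3)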

# ============================================================================
# NEW: Parser to extract question structure
# ============================================================================
def parse_question_document(doc_text, metadata):
    """Extract question and choices from document text."""
    lines = doc_text.split('\n')
    question_lines = []
    options_started = False
    options = {}

    for line in lines:
        line = line.strip()
        if not line:
            continue
        option_match = re.match(r'^([A-E])[\.\)]\s*(.+)$', line)
        if option_match:
            options_started = True
            letter = option_match.group(1)
            text = option_match.group(2).strip()
            options[letter] = text
        elif not options_started:
            question_lines.append(line)

    question_text = ' '.join(question_lines).strip()
    answer_idx = metadata.get('answer_idx', 'N/A')

    return {
        'question': question_text,
        'choices': options,
        'correct_answer': answer_idx
    }
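
# For reference, the parser above assumes stored documents roughly shaped like this
# (hypothetical example, not taken from the actual database):
#
#   A 62-year-old woman presents with sudden-onset dyspnea...
#   A. Pulmonary embolism
#   B. Pneumothorax
#   C. Myocardial infarction
#   D. Aortic dissection
#   E. Pericarditis
#
# Everything before the first "A."/"A)" line becomes the question stem; lines
# matching "<letter>. <text>" or "<letter>) <text>" become answer choices.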

# ============================================================================
# NEW: AI generation functions
# ============================================================================
def generate_choice_explanations(question, choices, correct_answer):
    """Generate explanations for why each choice is correct/wrong."""
    choices_text = '\n'.join([f"{k}. {v}" for k, v in choices.items()])

    prompt = f"""You are a medical educator. For this USMLE-style question, explain why EACH answer choice is correct or incorrect.
QUESTION:
{question}
ANSWER CHOICES:
{choices_text}
CORRECT ANSWER: {correct_answer}
Provide a 1-2 sentence explanation for EACH choice (A through E) explaining why it is correct or incorrect. Format as:
A. [Choice text] - [Explanation]
B. [Choice text] - [Explanation]
C. [Choice text] - [Explanation]
D. [Choice text] - [Explanation]
E. [Choice text] - [Explanation]"""

    # Using Claude
    message = claude_client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=1000,
        messages=[{"role": "user", "content": prompt}]
    )
    return message.content[0].text

    # OR using OpenAI (uncomment if using; note this is the legacy openai<1.0 interface):
    # response = openai.ChatCompletion.create(
    #     model="gpt-4",
    #     messages=[{"role": "user", "content": prompt}],
    #     max_tokens=1000
    # )
    # return response.choices[0].message.content

def generate_similar_question(original_question, choices, correct_answer):
    """Generate a new question based on the exemplar."""
    choices_text = '\n'.join([f"{k}. {v}" for k, v in choices.items()])

    prompt = f"""You are a medical educator. Based on this USMLE-style question, create a NEW similar question that tests the SAME medical concept but with a different clinical scenario.
ORIGINAL QUESTION:
{original_question}
ANSWER CHOICES:
{choices_text}
CORRECT ANSWER: {correct_answer}
Create a NEW question that:
1. Tests the same medical concept
2. Uses a different patient scenario
3. Has 5 answer choices (A-E)
4. Includes explanations for why each choice is correct/incorrect
Format your response EXACTLY as:
NEW QUESTION:
[Your new question text]
ANSWER CHOICES:
A. [Choice A]
B. [Choice B]
C. [Choice C]
D. [Choice D]
E. [Choice E]
CORRECT ANSWER: [Letter]
EXPLANATIONS:
A. [Choice A text] - [Explanation]
B. [Choice B text] - [Explanation]
C. [Choice C text] - [Explanation]
D. [Choice D text] - [Explanation]
E. [Choice E text] - [Explanation]"""

    # Using Claude
    message = claude_client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=2000,
        messages=[{"role": "user", "content": prompt}]
    )
    return message.content[0].text

    # OR using OpenAI (legacy openai<1.0 interface):
    # response = openai.ChatCompletion.create(
    #     model="gpt-4",
    #     messages=[{"role": "user", "content": prompt}],
    #     max_tokens=2000
    # )
    # return response.choices[0].message.content

# ============================================================================
# NEW: Format complete output
# ============================================================================
def format_complete_output(exemplar_num, parsed, original_explanation, choice_explanations, new_question_text):
    """Format everything into readable plain text."""
    choices_text = '\n'.join([f"{k}. {v}" for k, v in parsed['choices'].items()])

    output = f"""{'='*80}
EXEMPLAR {exemplar_num}
{'='*80}
ORIGINAL QUESTION:
{parsed['question']}
ANSWER CHOICES:
{choices_text}
CORRECT ANSWER: {parsed['correct_answer']}
EXPLANATION FOR EACH CHOICE:
{choice_explanations}
"""

    if original_explanation:
        output += f"\nORIGINAL EXPLANATION FROM DATABASE:\n{original_explanation}\n"

    output += f"""
{'-'*80}
AI-GENERATED SIMILAR QUESTION:
{'-'*80}
{new_question_text}
{'='*80}
"""
    return output

# ============================================================================
# MODIFIED: API endpoint with full generation
# ============================================================================
app = FastAPI()

class SearchRequest(BaseModel):
    query: str
    num_results: int = 3
    source_filter: str | None = None
    generate_ai: bool = True  # Option to skip AI generation for faster response

@app.post("/search")  # route path assumed; the original snippet did not show the decorator
def api_search(req: SearchRequest):
    """Search and return complete formatted exemplars with AI-generated content."""
    print(f"🔍 Searching for: {req.query}")
    r = search(req.query, req.num_results, req.source_filter)

    if not r['documents'][0]:
        return {"output": "No results found."}

    complete_output = f"SEARCH QUERY: {req.query}\n"
    complete_output += f"FOUND {len(r['documents'][0])} EXEMPLARS\n\n"

    for i in range(len(r['documents'][0])):
        print(f"Processing exemplar {i+1}...")
        doc_text = r['documents'][0][i]
        metadata = r['metadatas'][0][i]

        # Parse the exemplar
        parsed = parse_question_document(doc_text, metadata)
        original_explanation = metadata.get('explanation', '')

        if req.generate_ai:
            # Generate AI content
            print("  Generating choice explanations...")
            choice_explanations = generate_choice_explanations(
                parsed['question'],
                parsed['choices'],
                parsed['correct_answer']
            )
            print("  Generating similar question...")
            new_question = generate_similar_question(
                parsed['question'],
                parsed['choices'],
                parsed['correct_answer']
            )
        else:
            choice_explanations = "(AI generation skipped)"
            new_question = "(AI generation skipped)"

        # Format complete output
        formatted = format_complete_output(
            i + 1,
            parsed,
            original_explanation,
            choice_explanations,
            new_question
        )
        complete_output += formatted

    return {
        "output": complete_output,
        "content_type": "text/plain"
    }

# Gradio UI (simplified - just shows we have it)
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    gr.Markdown("# 🏥 MedQA Search with AI Generation")
    query_input = gr.Textbox(label="Query")
    output = gr.Textbox(label="Results", lines=50)

app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
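
# Example request against the JSON endpoint (the "/search" path matches the assumed
# decorator above; the payload values are only illustrative):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/search",
#       json={"query": "management of diabetic ketoacidosis", "num_results": 2, "generate_ai": False},
#   )
#   print(resp.json()["output"])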