"""MedQA semantic search service.

Loads a ChromaDB collection of USMLE-style questions, embeds queries with
MedCPT, and (optionally) uses Claude to generate per-choice explanations and
a novel similar question for each retrieved exemplar.  Exposed as a FastAPI
endpoint with a minimal Gradio front-end mounted at "/".
"""

import os

# Must be set before chromadb is imported: disables Chroma's anonymous
# usage telemetry.
os.environ['ANONYMIZED_TELEMETRY'] = 'False'

import re
import zipfile
from typing import Optional

import anthropic  # pip install anthropic
import chromadb
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# ============================================================================
# Database / model bootstrap (runs once at import time)
# ============================================================================

DB_PATH = "./medqa_db"

# First run: unpack the bundled zip so the persistent Chroma store exists.
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    print("📦 Extracting database...")
    with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
        z.extractall(".")
    print("✅ Database extracted")

print("🔌 Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"✅ Loaded {collection.count()} questions")

print("🧠 Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("✅ Model ready")

# Claude client; requires ANTHROPIC_API_KEY in the environment.
# (An OpenAI client could be substituted here — see the commented
# alternatives inside the generation functions below.)
claude_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))


# ============================================================================
# Deduplication
# ============================================================================

def deduplicate_results(results, target_count):
    """Filter near-duplicate hits from a Chroma query result.

    Two hits are considered duplicates when their distances are within 0.08
    of each other, or when they share the same answer and their distances are
    within 0.15.  Keeps at most ``target_count`` survivors, preserving the
    original (distance-sorted) order.

    Args:
        results: Raw ``collection.query`` result dict (single-query batch).
        target_count: Maximum number of results to keep.

    Returns:
        A result dict with the same shape, containing only the survivors.
    """
    if not results['documents'][0]:
        return results

    documents = results['documents'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    selected_indices = []
    for i in range(len(documents)):
        is_duplicate = False
        current_answer = metadatas[i].get('answer', '')

        for j in selected_indices:
            selected_answer = metadatas[j].get('answer', '')
            dist_diff = abs(distances[i] - distances[j])

            # Nearly identical embedding distance -> almost certainly the
            # same underlying question.
            if dist_diff < 0.08:
                is_duplicate = True
                break
            # Same answer and close distance -> treat as a duplicate too.
            if current_answer == selected_answer and dist_diff < 0.15:
                is_duplicate = True
                break

        if not is_duplicate:
            selected_indices.append(i)
            if len(selected_indices) >= target_count:
                break

    return {
        'documents': [[documents[i] for i in selected_indices]],
        'metadatas': [[metadatas[i] for i in selected_indices]],
        'distances': [[distances[i] for i in selected_indices]],
        'ids': ([[results['ids'][0][i] for i in selected_indices]]
                if 'ids' in results else None),
    }


# ============================================================================
# Search
# ============================================================================

def search(query, num_results=3, source_filter=None):
    """Embed ``query`` with MedCPT and return deduplicated nearest questions.

    Args:
        query: Free-text search query.
        num_results: Number of results desired after deduplication.
        source_filter: Optional metadata ``source`` value to restrict the
            search; ``None`` or ``"all"`` means no filter.

    Returns:
        A Chroma-style result dict with at most ``num_results`` hits.
    """
    emb = model.encode(query).tolist()

    where_clause = None
    if source_filter and source_filter != "all":
        where_clause = {"source": source_filter}

    # Over-fetch so deduplication still leaves enough survivors (capped at 50).
    fetch_count = min(num_results * 4, 50)
    results = collection.query(
        query_embeddings=[emb],
        n_results=fetch_count,
        where=where_clause,
    )
    return deduplicate_results(results, num_results)


# ============================================================================
# Question parsing
# ============================================================================

def parse_question_document(doc_text, metadata):
    """Extract the question stem and answer choices from a document string.

    Expects choices formatted as lines beginning ``A.``/``A)`` … ``E.``/``E)``;
    everything before the first such line is the question stem.

    Args:
        doc_text: Raw stored document text.
        metadata: Metadata dict for the document; ``answer_idx`` is read
            as the correct-answer letter (defaults to ``'N/A'``).

    Returns:
        Dict with keys ``question`` (str), ``choices`` (letter -> text),
        and ``correct_answer`` (str).
    """
    lines = doc_text.split('\n')
    question_lines = []
    options_started = False
    options = {}

    for line in lines:
        line = line.strip()
        if not line:
            continue
        option_match = re.match(r'^([A-E])[\.\)]\s*(.+)$', line)
        if option_match:
            options_started = True
            letter = option_match.group(1)
            text = option_match.group(2).strip()
            options[letter] = text
        elif not options_started:
            question_lines.append(line)

    question_text = ' '.join(question_lines).strip()
    answer_idx = metadata.get('answer_idx', 'N/A')

    return {
        'question': question_text,
        'choices': options,
        'correct_answer': answer_idx,
    }


# ============================================================================
# AI generation
# ============================================================================

def generate_choice_explanations(question, choices, correct_answer):
    """Ask Claude to explain why each answer choice is correct or wrong.

    Args:
        question: Question stem text.
        choices: Mapping of choice letter -> choice text.
        correct_answer: Letter of the correct choice.

    Returns:
        The model's plain-text explanation block.
    """
    choices_text = '\n'.join([f"{k}. {v}" for k, v in choices.items()])

    prompt = f"""You are a medical educator. For this USMLE-style question, explain why EACH answer choice is correct or incorrect.

QUESTION:
{question}

ANSWER CHOICES:
{choices_text}

CORRECT ANSWER: {correct_answer}

Provide a 1-2 sentence explanation for EACH choice (A through E) explaining why it is correct or incorrect.

Format as:
A. [Choice text] - [Explanation]
B. [Choice text] - [Explanation]
C. [Choice text] - [Explanation]
D. [Choice text] - [Explanation]
E. [Choice text] - [Explanation]"""

    # Using Claude
    message = claude_client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=1000,
        messages=[{"role": "user", "content": prompt}],
    )
    return message.content[0].text

    # OR using OpenAI (uncomment if using):
    # response = openai.ChatCompletion.create(
    #     model="gpt-4",
    #     messages=[{"role": "user", "content": prompt}],
    #     max_tokens=1000,
    # )
    # return response.choices[0].message.content


def generate_similar_question(original_question, choices, correct_answer):
    """Ask Claude to write a new question testing the same concept.

    Args:
        original_question: The exemplar's question stem.
        choices: Mapping of choice letter -> choice text.
        correct_answer: Letter of the correct choice.

    Returns:
        The model's plain-text new question (stem, choices, answer,
        explanations) in the requested format.
    """
    choices_text = '\n'.join([f"{k}. {v}" for k, v in choices.items()])

    # BUG FIX: the original prompt interpolated an undefined name `question`
    # (the parameter is `original_question`), raising NameError on every call.
    prompt = f"""You are a medical educator. Based on this USMLE-style question, create a NEW similar question that tests the SAME medical concept but with a different clinical scenario.

ORIGINAL QUESTION:
{original_question}

ANSWER CHOICES:
{choices_text}

CORRECT ANSWER: {correct_answer}

Create a NEW question that:
1. Tests the same medical concept
2. Uses a different patient scenario
3. Has 5 answer choices (A-E)
4. Includes explanations for why each choice is correct/incorrect

Format your response EXACTLY as:

NEW QUESTION:
[Your new question text]

ANSWER CHOICES:
A. [Choice A]
B. [Choice B]
C. [Choice C]
D. [Choice D]
E. [Choice E]

CORRECT ANSWER: [Letter]

EXPLANATIONS:
A. [Choice A text] - [Explanation]
B. [Choice B text] - [Explanation]
C. [Choice C text] - [Explanation]
D. [Choice D text] - [Explanation]
E. [Choice E text] - [Explanation]"""

    # Using Claude
    message = claude_client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=2000,
        messages=[{"role": "user", "content": prompt}],
    )
    return message.content[0].text

    # OR using OpenAI:
    # response = openai.ChatCompletion.create(
    #     model="gpt-4",
    #     messages=[{"role": "user", "content": prompt}],
    #     max_tokens=2000,
    # )
    # return response.choices[0].message.content


# ============================================================================
# Output formatting
# ============================================================================

def format_complete_output(exemplar_num, parsed, original_explanation,
                           choice_explanations, new_question_text):
    """Assemble one exemplar's full report as readable plain text.

    Args:
        exemplar_num: 1-based exemplar index for the header.
        parsed: Output of :func:`parse_question_document`.
        original_explanation: Explanation stored in the database ('' if none).
        choice_explanations: AI-generated per-choice explanations.
        new_question_text: AI-generated similar question block.

    Returns:
        A formatted multi-section string.
    """
    choices_text = '\n'.join(
        [f"{k}. {v}" for k, v in parsed['choices'].items()]
    )

    output = f"""{'='*80}
EXEMPLAR {exemplar_num}
{'='*80}

ORIGINAL QUESTION:
{parsed['question']}

ANSWER CHOICES:
{choices_text}

CORRECT ANSWER: {parsed['correct_answer']}

EXPLANATION FOR EACH CHOICE:
{choice_explanations}
"""

    if original_explanation:
        output += f"\nORIGINAL EXPLANATION FROM DATABASE:\n{original_explanation}\n"

    output += f"""
{'-'*80}
AI-GENERATED SIMILAR QUESTION:
{'-'*80}

{new_question_text}

{'='*80}

"""
    return output


# ============================================================================
# API
# ============================================================================

app = FastAPI()


class SearchRequest(BaseModel):
    """Request body for /search_medqa."""
    query: str
    num_results: int = 3
    # FIX: annotation was `str = None`; Optional matches the None default.
    source_filter: Optional[str] = None
    generate_ai: bool = True  # Set False to skip AI generation (faster).


@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """Search and return complete formatted exemplars with AI-generated content."""
    print(f"🔍 Searching for: {req.query}")

    r = search(req.query, req.num_results, req.source_filter)
    if not r['documents'][0]:
        return {"output": "No results found."}

    complete_output = f"SEARCH QUERY: {req.query}\n"
    complete_output += f"FOUND {len(r['documents'][0])} EXEMPLARS\n\n"

    for i in range(len(r['documents'][0])):
        print(f"Processing exemplar {i+1}...")
        doc_text = r['documents'][0][i]
        metadata = r['metadatas'][0][i]

        # Parse the exemplar
        parsed = parse_question_document(doc_text, metadata)
        original_explanation = metadata.get('explanation', '')

        if req.generate_ai:
            # Generate AI content (two model calls per exemplar).
            print(f"  Generating choice explanations...")
            choice_explanations = generate_choice_explanations(
                parsed['question'], parsed['choices'], parsed['correct_answer']
            )
            print(f"  Generating similar question...")
            new_question = generate_similar_question(
                parsed['question'], parsed['choices'], parsed['correct_answer']
            )
        else:
            choice_explanations = "(AI generation skipped)"
            new_question = "(AI generation skipped)"

        # Format complete output
        formatted = format_complete_output(
            i + 1, parsed, original_explanation,
            choice_explanations, new_question,
        )
        complete_output += formatted

    return {
        "output": complete_output,
        "content_type": "text/plain",
    }


# ============================================================================
# Gradio UI (simplified - just shows we have it)
# ============================================================================

with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    gr.Markdown("# 🏥 MedQA Search with AI Generation")
    query_input = gr.Textbox(label="Query")
    output = gr.Textbox(label="Results", lines=50)

app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)