# app.py - Hugging Face Spaces version - FIXED import os import zipfile from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import chromadb from sentence_transformers import SentenceTransformer import gradio as gr # Database path DB_PATH = "./medqa_db" ZIP_PATH = "./medqa_db.zip" # Extract database if needed if not os.path.exists(DB_PATH) and os.path.exists(ZIP_PATH): print("Extracting database from zip file...") with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref: zip_ref.extractall(".") print("Database extracted successfully!") # Initialize print(f"Loading database from: {DB_PATH}") client = chromadb.PersistentClient(path=DB_PATH) collection = client.get_collection("medqa") print(f"Collection loaded with {collection.count()} items") print(f"Loading MedCPT model...") model = SentenceTransformer('ncbi/MedCPT-Query-Encoder') print("Initialization complete!") # Gradio interface function def search_interface(query: str, num_results: int = 3): """Simple web interface for testing""" if not query.strip(): return "Please enter a search query." try: embedding = model.encode(query).tolist() results = collection.query( query_embeddings=[embedding], n_results=int(num_results) ) if not results['documents'][0]: return "No results found." output = "" for i in range(len(results['documents'][0])): output += f"\n{'='*60}\n" output += f"Example {i+1}\n" output += f"{'='*60}\n" output += results['documents'][0][i] + "\n" output += f"\nAnswer: {results['metadatas'][0][i].get('answer', 'N/A')}\n" output += f"Similarity: {1 - results['distances'][0][i]:.3f}\n" return output except Exception as e: return f"Error: {str(e)}" # Create Gradio interface with gr.Blocks(title="MedQA Search") as demo: gr.Markdown("# MedQA Search - USMLE Question Database") gr.Markdown("Search for similar USMLE Step 1 questions using semantic similarity") with gr.Row(): with gr.Column(): query_input = gr.Textbox( label="Medical Topic or Clinical Scenario", placeholder="e.g., hyponatremia", lines=2 ) num_results_slider = gr.Slider( minimum=1, maximum=5, value=3, step=1, label="Number of Examples" ) search_btn = gr.Button("Search", variant="primary") with gr.Column(): output_text = gr.Textbox( label="Similar USMLE Questions", lines=25, max_lines=50 ) search_btn.click( fn=search_interface, inputs=[query_input, num_results_slider], outputs=output_text ) gr.Examples( examples=[ ["hyponatremia", 3], ["myocardial infarction", 2], ["diabetic ketoacidosis", 3] ], inputs=[query_input, num_results_slider] ) # FastAPI for API endpoints app = FastAPI(title="MedQA Search API") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) class SearchRequest(BaseModel): query: str num_results: int = 3 class SearchResponse(BaseModel): results: list[dict] @app.get("/") async def root(): return { "message": "MedQA Search API - Hugging Face Version", "status": "running", "collection_count": collection.count() } @app.post("/search_medqa", response_model=SearchResponse) async def search_medqa(request: SearchRequest): """Search MedQA database for similar USMLE questions""" try: embedding = model.encode(request.query).tolist() results = collection.query( query_embeddings=[embedding], n_results=request.num_results ) formatted_results = [] for i in range(len(results['documents'][0])): formatted_results.append({ "example_number": i + 1, "question": results['documents'][0][i], "answer": results['metadatas'][0][i].get('answer', 'N/A'), "distance": results['distances'][0][i] if 'distances' in results else None }) return Sea