# app.py - Hugging Face Spaces version
"""MedQA semantic search service for Hugging Face Spaces.

Serves a FastAPI JSON endpoint (``POST /search_medqa``) and a Gradio web UI.
Both embed a free-text query with the MedCPT query encoder and return the
nearest stored questions from a persistent ChromaDB collection.
"""
import logging
import os

import chromadb
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)

# Database path; may be overridden with the MEDQA_DB_PATH environment variable.
DB_PATH = os.environ.get("MEDQA_DB_PATH", "./medqa_db")

# Initialize (runs once at import time: opens the DB and loads the encoder)
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")

print("Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("Initialization complete!")


def _search(query: str, num_results: int) -> list[dict]:
    """Embed *query* and return up to ``num_results`` nearest stored questions.

    Each hit is a dict with keys:
      * ``example_number`` - 1-based rank of the hit,
      * ``question`` - the stored document text,
      * ``answer`` - the ``answer`` metadata field, or ``'N/A'`` if absent,
      * ``distance`` - the ChromaDB distance, or ``None`` if the query result
        carries no distances.

    Raises whatever ``model.encode`` or ``collection.query`` raise.
    """
    embedding = model.encode(query).tolist()
    results = collection.query(
        query_embeddings=[embedding],
        n_results=num_results,
    )
    # ChromaDB returns one inner list per query embedding; exactly one was sent.
    documents = results["documents"][0]
    placeholders = [None] * len(documents)
    metadatas_all = results.get("metadatas")
    metadatas = metadatas_all[0] if metadatas_all else placeholders
    distances_all = results.get("distances")
    distances = distances_all[0] if distances_all else placeholders
    return [
        {
            "example_number": rank,
            "question": document,
            # A record stored without metadata yields None here, not a dict.
            "answer": (metadata or {}).get("answer", "N/A"),
            "distance": distance,
        }
        for rank, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1
        )
    ]


# FastAPI app
app = FastAPI(title="MedQA Search API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class SearchRequest(BaseModel):
    """Body of ``POST /search_medqa``."""

    query: str
    # ChromaDB rejects non-positive n_results; validating here turns that into
    # a 422 client error instead of a 500.
    num_results: int = Field(3, ge=1)


class SearchResponse(BaseModel):
    """Response of ``POST /search_medqa``: one dict per hit (see ``_search``)."""

    results: list[dict]


@app.get("/")
async def root():
    """Report service status and the number of records in the collection."""
    return {
        "message": "MedQA Search API - Hugging Face Version",
        "status": "running",
        "collection_count": collection.count()
    }


@app.post("/search_medqa", response_model=SearchResponse)
def search_medqa(request: SearchRequest):
    """Search MedQA database for similar USMLE questions

    Declared as a plain ``def`` rather than ``async def``: encoding and the
    vector query are blocking calls, and FastAPI runs sync endpoints in a
    worker thread pool instead of on the event loop.

    Raises:
        HTTPException: status 500 with the underlying error message if the
            search fails.
    """
    try:
        hits = _search(request.query, request.num_results)
    except Exception as e:
        logger.exception("MedQA search failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return SearchResponse(results=hits)


# Gradio interface (optional - gives you a web UI)
def search_interface(query: str, num_results: int = 3):
    """Simple web interface for testing

    Returns the hits as one human-readable text block, or an ``Error: ...``
    string if the search fails.
    """
    try:
        # The slider value may arrive as a float; ChromaDB needs an int.
        hits = _search(query, int(num_results))
    except Exception as e:
        return f"Error: {str(e)}"

    separator = "=" * 60
    sections = []
    for hit in hits:
        distance = hit["distance"]
        # NOTE(review): 1 - distance is a similarity only if the collection uses
        # cosine distance; with ChromaDB's default L2 space it is not bounded
        # to [0, 1] - confirm the collection's "hnsw:space" setting.
        similarity = f"{1 - distance:.3f}" if distance is not None else "N/A"
        sections.append(
            f"\n{separator}\n"
            f"Example {hit['example_number']}\n"
            f"{separator}\n"
            f"{hit['question']}\n"
            f"\nAnswer: {hit['answer']}\n"
            f"Similarity: {similarity}\n"
        )
    return "".join(sections)


# Create Gradio interface
demo = gr.Interface(
    fn=search_interface,
    inputs=[
        gr.Textbox(label="Medical Topic or Clinical Scenario", placeholder="e.g., hyponatremia"),
        gr.Slider(1, 5, value=3, step=1, label="Number of Examples")
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    title="MedQA Search - USMLE Question Database",
    description="Search for similar USMLE Step 1 questions using semantic similarity"
)

# Mount Gradio app and FastAPI
# NOTE(review): the JSON route for GET "/" above is registered before this
# mount. Starlette matches routes in registration order, so the Gradio page at
# the root may be shadowed by the JSON status response. Confirm this, and
# consider moving one of them to a distinct path.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)