# app.py - Hugging Face Spaces version - FIXED
"""MedQA semantic search service.

Serves one ChromaDB collection ("medqa") through a JSON API (FastAPI) and an
interactive UI (Gradio). Each query is embedded with the MedCPT query encoder
and the nearest entries of the collection are returned.
"""
import os
import zipfile

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr

# Database setup
DB_PATH = "./medqa_db"
ZIP_PATH = "./medqa_db.zip"

# Extract database if needed: only when the unpacked directory is missing but
# the archive is present (i.e. first start).
if not os.path.exists(DB_PATH) and os.path.exists(ZIP_PATH):
    print("Extracting database from zip file...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        # Extracts into the working directory; assumes the archive's top-level
        # entry is the medqa_db/ folder — TODO confirm against the zip layout.
        zip_ref.extractall(".")
    print("Database extracted successfully!")

# Load database and model once at import time; every request reuses them.
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"Collection loaded with {collection.count()} items")

print("Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("Initialization complete!")


def _query_collection(query: str, num_results: int) -> list[dict]:
    """Embed *query* and return the *num_results* nearest collection entries.

    Shared by the Gradio UI and the JSON API so both search identically.

    Returns:
        One dict per hit, in the order Chroma returns them, with keys
        ``example_number`` (1-based rank), ``question`` (stored document),
        ``answer`` (metadata "answer", or "N/A" when absent) and ``distance``
        (raw Chroma distance).
    """
    embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[embedding], n_results=num_results)
    # Chroma returns one list per query embedding; we sent exactly one.
    documents = results['documents'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]
    return [
        {
            "example_number": rank,
            "question": document,
            # A hit stored without metadata comes back as None, not {}.
            "answer": (metadata or {}).get('answer', 'N/A'),
            "distance": distance,
        }
        for rank, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1
        )
    ]


# Search function for Gradio
def search_gradio(query, num_results=3):
    """Run a search and render the hits as plain text for the Gradio textbox.

    Args:
        query: Free-text medical query from the UI.
        num_results: Number of hits to return; the slider may deliver a float,
            so it is coerced to int.

    Returns:
        The formatted hits, a prompt for blank input, a "no matches" message
        for an empty result, or ``"Error: ..."`` on failure. Never raises, so
        the UI always shows a message.
    """
    if not query or not query.strip():
        return "Please enter a query."
    try:
        matches = _query_collection(query, int(num_results))
        if not matches:
            return "No matching questions found."
        separator = "=" * 60
        sections = []
        for match in matches:
            sections.append(
                f"\n{separator}\nExample {match['example_number']}\n{separator}\n"
                f"{match['question']}\n"
                f"\nAnswer: {match['answer']}\n"
                # NOTE(review): `1 - distance` is a similarity only if the
                # collection uses cosine distance (hnsw:space="cosine");
                # Chroma's default is L2 — confirm how "medqa" was built.
                f"Similarity: {1 - match['distance']:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:
        return f"Error: {str(e)}"


# FastAPI app
app = FastAPI(title="MedQA Search API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class SearchRequest(BaseModel):
    """Request body for POST /search_medqa."""

    query: str
    # Must be positive: Chroma rejects non-positive n_results, which used to
    # surface as an opaque 500; now FastAPI answers with a 422 instead.
    num_results: int = Field(default=3, ge=1)


class SearchResponse(BaseModel):
    """Response body for POST /search_medqa: one dict per hit."""

    results: list[dict]


# Plain `def` (not `async def`): collection.count() is a blocking call, and
# FastAPI runs sync endpoints in its threadpool instead of on the event loop.
@app.get("/")
def root():
    """Report service status and the number of items in the collection."""
    return {
        "message": "MedQA Search API",
        "status": "running",
        "collection_count": collection.count()
    }


# Plain `def` so the blocking embedding + vector query run in FastAPI's
# threadpool rather than stalling every other request on the event loop.
@app.post("/search_medqa", response_model=SearchResponse)
def search_medqa(request: SearchRequest):
    """Return the nearest stored questions for ``request.query``.

    Raises:
        HTTPException: 400 for a blank query (matching the Gradio UI, which
            also rejects blank input); 500 if embedding or querying fails.
    """
    if not request.query.strip():
        raise HTTPException(status_code=400, detail="Query must not be empty.")
    try:
        matches = _query_collection(request.query, request.num_results)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return SearchResponse(results=matches)


# Gradio interface
demo = gr.Interface(
    fn=search_gradio,
    inputs=[
        gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
        gr.Slider(1, 5, value=3, step=1, label="Number of Results")
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    title="MedQA Search - USMLE Question Database",
    description="Search for similar USMLE Step 1 questions using semantic similarity",
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
)

# Mount Gradio to FastAPI
# NOTE(review): the GET "/" route above is registered before this mount, so
# Starlette will likely match it first and serve the JSON status instead of the
# Gradio UI at "/" — confirm, and consider moving one of them to another path.
app = gr.mount_gradio_app(app, demo, path="/")