"""MedQA semantic search service.

Exposes the same search over the "medqa" Chroma collection through two
front ends: a Gradio UI mounted at "/" and a JSON endpoint at
POST /search_medqa.
"""

import os
import zipfile

import chromadb
import gradio as gr
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Extract database
DB_PATH = "./medqa_db"
DB_ARCHIVE = "./medqa_db.zip"

# Only unpack when the database directory is missing and an archive is present.
# NOTE(review): assumes the archive contains a top-level "medqa_db" directory,
# since it is extracted into the current working directory -- confirm.
if not os.path.exists(DB_PATH) and os.path.exists(DB_ARCHIVE):
    print("Extracting database...")
    with zipfile.ZipFile(DB_ARCHIVE, "r") as zip_ref:
        zip_ref.extractall(".")
    print("Extracted!")

# Load database and model
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"Loaded {collection.count()} items")
print("Loading model...")
model = SentenceTransformer("ncbi/MedCPT-Query-Encoder")
print("Ready!")


# Search function
def search_function(query, num_results=3):
    """Embed *query* and return the raw Chroma query result.

    Args:
        query: Free-text search string.
        num_results: Number of nearest neighbours to request; coerced with
            ``int()`` because the Gradio slider may deliver a float.

    Returns:
        The object returned by ``collection.query`` for a single query
        embedding; callers read its ``documents``, ``metadatas`` and
        ``distances`` entries at index 0.
    """
    embedding = model.encode(query).tolist()
    return collection.query(
        query_embeddings=[embedding],
        n_results=int(num_results),
    )


def _collect_hits(results):
    """Return ``(document, metadata, distance)`` tuples for the first query.

    A missing (``None``) metadata entry is replaced by an empty dict so that
    callers can use ``.get`` on it unconditionally.
    """
    return [
        (document, metadata or {}, distance)
        for document, metadata, distance in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]


# Gradio interface
def search_gradio(query, num_results=3):
    """Run a search and render the hits as plain text for the Gradio UI.

    Args:
        query: Text from the query box; ``None``, empty or whitespace-only
            input yields a prompt instead of a search.
        num_results: Number of results requested by the slider.

    Returns:
        A formatted multi-section string, a prompt to enter a query, or an
        ``"Error: ..."`` message if the search raised.
    """
    if not query or not query.strip():
        return "Please enter a search query."

    # UI boundary: any failure is reported as text in the output box rather
    # than propagated.
    try:
        results = search_function(query, num_results)
        rule = "=" * 60
        # NOTE(review): "1 - distance" is only a true similarity if the
        # collection uses a distance where that holds (e.g. cosine) -- the
        # metric is not visible here; confirm against how the DB was built.
        sections = [
            f"\n{rule}\nExample {rank}\n{rule}\n"
            f"{document}\n"
            f"\nAnswer: {metadata.get('answer', 'N/A')}\n"
            f"Similarity: {1 - distance:.3f}\n"
            for rank, (document, metadata, distance) in enumerate(
                _collect_hits(results), start=1
            )
        ]
        return "".join(sections)
    except Exception as e:
        return f"Error: {str(e)}"


demo = gr.Interface(
    fn=search_gradio,
    inputs=[
        gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
        gr.Slider(1, 5, value=3, step=1, label="Number of Results"),
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    title="MedQA Search",
    description="Search for similar USMLE Step 1 questions",
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]],
)

# FastAPI for ChatGPT
app = FastAPI()


class SearchRequest(BaseModel):
    """Request body for POST /search_medqa."""

    query: str
    # A non-positive result count cannot be satisfied by the vector store, so
    # reject it at validation time (422) instead of failing inside the query.
    num_results: int = Field(default=3, ge=1)


class SearchResponse(BaseModel):
    """Response body for POST /search_medqa: one dict per hit."""

    results: list[dict]


@app.post("/search_medqa")
async def api_search(request: SearchRequest) -> SearchResponse:
    """Search the collection and return the hits as structured JSON.

    Each result dict carries ``example_number`` (1-based rank), ``question``
    (the stored document), ``answer`` (from metadata, ``"N/A"`` if absent)
    and ``distance`` (the raw distance reported by the collection).
    """
    # Encoding and the vector query are blocking calls; run them in the
    # threadpool so they do not stall the event loop for other requests.
    results = await run_in_threadpool(
        search_function, request.query, request.num_results
    )
    formatted = [
        {
            "example_number": rank,
            "question": document,
            "answer": metadata.get("answer", "N/A"),
            "distance": distance,
        }
        for rank, (document, metadata, distance) in enumerate(
            _collect_hits(results), start=1
        )
    ]
    return SearchResponse(results=formatted)


# Mount Gradio on FastAPI
app = gr.mount_gradio_app(app, demo, path="/")