Spaces:
Sleeping
Sleeping
| import os | |
| import zipfile | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
# Extract the bundled database archive on first start-up.
# The loading code reads the Chroma store from DB_PATH, so the archive is
# expected to contain that directory at its top level.
DB_PATH = "./medqa_db"
# Derive the archive name from DB_PATH instead of repeating the literal.
DB_ZIP_PATH = f"{DB_PATH}.zip"
if not os.path.exists(DB_PATH) and os.path.exists(DB_ZIP_PATH):
    print("Extracting database...")
    with zipfile.ZipFile(DB_ZIP_PATH, "r") as zip_ref:
        # Extract into DB_PATH's parent so the archive's top-level directory
        # lands at DB_PATH ("." for the default path, as before).
        zip_ref.extractall(os.path.dirname(DB_PATH) or ".")
    print("Extracted!")
    if not os.path.isdir(DB_PATH):
        # Surface a mis-packaged archive here rather than as an obscure
        # "collection not found" error when the database is opened.
        print(f"Warning: {DB_ZIP_PATH} did not contain {DB_PATH}")
# Load the persisted Chroma collection and the query encoder.
print(f"Loading database from: {DB_PATH}")
# Fail fast with a clear message when the database is missing. The extraction
# step above only runs while DB_PATH does not exist, so if the client below
# creates an empty store directory (it appears to create the path when
# missing; verify against the Chroma docs), a later upload of the zip would
# never be extracted.
if not os.path.isdir(DB_PATH):
    raise FileNotFoundError(
        f"Database directory {DB_PATH!r} not found; provide {DB_PATH}.zip "
        f"or the extracted directory before starting the app."
    )
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"Loaded {collection.count()} items")
print("Loading model...")
# NOTE(review): the MedCPT model card pools the [CLS] token, whereas wrapping
# a plain Hugging Face checkpoint in SentenceTransformer may fall back to mean
# pooling. The query encoder must match whatever produced the stored
# embeddings, so confirm that before changing this line.
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("Ready!")
def search_medqa(query, num_results=3):
    """Return the stored MedQA examples most similar to *query*, formatted as text.

    The query is embedded with the module-level ``model`` and looked up in the
    module-level Chroma ``collection``. Each hit is rendered as a banner with
    its rank, followed by the document text, the ``answer`` metadata field
    ("N/A" when absent) and a score of ``1 - distance``.

    Args:
        query: Free-text search string from the UI. ``None`` or whitespace-only
            input is treated as empty.
        num_results: Number of neighbours to request. It is coerced with
            ``int()``, so float slider values are accepted.

    Returns:
        One of the following:
        - the formatted results;
        - a prompt when the query is empty;
        - a notice when nothing matched;
        - ``"Error: <message>"`` if embedding or lookup raises.
    """
    # Guard against None as well as blank input; None previously crashed on .strip().
    if not query or not query.strip():
        return "Please enter a search query."
    try:
        embedding = model.encode(query).tolist()
        results = collection.query(
            query_embeddings=[embedding], n_results=int(num_results)
        )
        # Results are nested one list per query embedding; exactly one was sent.
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distances = results["distances"][0]
        if not documents:
            return "No matching questions found."

        separator = "=" * 60
        sections = []
        for rank, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1
        ):
            # An item stored without metadata yields None here, not a dict.
            answer = (metadata or {}).get("answer", "N/A")
            # NOTE(review): 1 - distance is only a [0, 1] similarity if the
            # collection uses cosine space; confirm its "hnsw:space" setting.
            sections.append(
                f"\n{separator}\nExample {rank}\n{separator}\n"
                f"{document}\n"
                f"\nAnswer: {answer}\n"
                f"Similarity: {1 - distance:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:  # UI boundary: report the failure in the output box
        return f"Error: {e}"
# Build the Gradio UI. A free-text query and a result-count slider are passed
# to search_medqa, and its formatted text is shown in one output box.
# Components are created in the same order as the Interface arguments.
_query_box = gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2)
_count_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of Results")
_results_box = gr.Textbox(label="Similar USMLE Questions", lines=20)

_examples = [
    ["hyponatremia", 3],
    ["myocardial infarction", 2],
]

demo = gr.Interface(
    fn=search_medqa,
    inputs=[_query_box, _count_slider],
    outputs=_results_box,
    title="MedQA Search",
    description="Search for similar USMLE Step 1 questions",
    examples=_examples,
)
# Start the web server when run as a script. Binding to 0.0.0.0 accepts
# connections on all interfaces, not just localhost. The __main__ guard keeps
# a plain import of this module (e.g. by tooling that only needs `demo`) from
# starting a server as a side effect.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)