# Captured from a Hugging Face Space page (Space status shown: "Sleeping").
| import os | |
| import zipfile | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
# Extract and load database
DB_PATH = "./medqa_db"


def _unpack_db_archive_if_missing():
    """Unzip ./medqa_db.zip into the working directory when DB_PATH is absent.

    Nothing happens if DB_PATH already exists, or if the archive is not present.
    """
    archive = "./medqa_db.zip"
    if os.path.exists(DB_PATH):
        return
    if not os.path.exists(archive):
        return
    # NOTE(review): assumes the archive's top-level entry is the medqa_db/
    # directory, so extracting into "." recreates DB_PATH — confirm.
    with zipfile.ZipFile(archive, "r") as bundle:
        bundle.extractall(".")


_unpack_db_archive_if_missing()

# Persistent Chroma store holding the pre-built "medqa" collection.
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
# Query encoder used to embed incoming search text.
model = SentenceTransformer("ncbi/MedCPT-Query-Encoder")
# Search function
def search(query, num_results=3):
    """Embed *query* and return its nearest neighbours from the collection.

    Args:
        query: Free-text search string passed to the module-level encoder.
        num_results: Number of hits to request; coerced with ``int()`` so a
            float from a UI slider is accepted.

    Returns:
        The raw result of ``collection.query`` for a single query embedding.
        Callers index ``['documents'][0]``, ``['metadatas'][0]`` and
        ``['distances'][0]``.
    """
    query_vector = model.encode(query).tolist()
    top_k = int(num_results)
    return collection.query(query_embeddings=[query_vector], n_results=top_k)
# Gradio UI
def ui_search(query, num_results=3):
    """Run a search for the Gradio interface and format the hits as text.

    Args:
        query: Text from the query box. ``None``, empty, or whitespace-only
            input is treated as "no query".
        num_results: Number of hits to request; forwarded to ``search``.

    Returns:
        A text report with one section per hit. Each section holds the
        stored question text, the ``answer`` metadata field (``"N/A"`` if
        missing), and ``1 - distance`` labelled as similarity. Returns
        ``"Enter a query"`` for empty input, and ``"Error: ..."`` if the
        search or formatting fails, so the UI never raises.
    """
    # Guard None before calling .strip(); Gradio may pass None for a cleared box.
    if not query or not query.strip():
        return "Enter a query"
    try:
        r = search(query, num_results)
        sep = "=" * 60
        sections = []
        hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
        for i, (doc, meta, dist) in enumerate(hits, start=1):
            # Chroma can return None for an item stored without metadata.
            answer = (meta or {}).get('answer', 'N/A')
            # NOTE(review): 1 - distance is only a similarity for cosine space;
            # Chroma's default space is L2 — confirm the collection's metric.
            sections.append(
                f"\n{sep}\nExample {i}\n{sep}\n"
                f"{doc}\n\nAnswer: {answer}\n"
                f"Similarity: {1 - dist:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:  # UI boundary: show the failure in the output box
        return f"Error: {e}"
# Gradio interface: a query textbox plus a 1-5 result-count slider feed
# ui_search, whose formatted text is shown in a large output textbox.
demo = gr.Interface(
    fn=ui_search,
    inputs=[
        gr.Textbox(
            label="Medical Query",
            placeholder="e.g., hyponatremia",
        ),
        gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Results"),
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    title="MedQA Search",
    # Each example row is [query, num_results].
    examples=[
        ["hyponatremia", 3],
        ["myocardial infarction", 2],
    ],
)
# FastAPI
app = FastAPI()


class SearchRequest(BaseModel):
    """JSON request body for api_search."""

    query: str  # free-text query to embed and look up
    num_results: int = 3  # number of nearest entries to return
# Registered on `app` so the endpoint is actually reachable. Without a route
# decorator this function was dead code. It is declared before Gradio is
# mounted at "/" further down, and routes are matched in registration order,
# so this path takes precedence over the mount.
# NOTE(review): the "/api/search" path is a chosen name — adjust if clients
# expect a different one.
@app.post("/api/search")
def api_search(req: SearchRequest):
    """Return the nearest MedQA entries for ``req.query`` as JSON.

    Args:
        req: Request body with ``query`` and optional ``num_results``
            (default 3), forwarded to ``search``.

    Returns:
        ``{"results": [...]}``. Each item has:
        - ``example_number``: 1-based rank
        - ``question``: stored document text
        - ``answer``: the ``answer`` metadata field, or ``"N/A"``
        - ``distance``: the raw distance reported by the collection
    """
    r = search(req.query, req.num_results)
    hits = zip(r["documents"][0], r["metadatas"][0], r["distances"][0])
    results = [
        {
            "example_number": i,
            "question": doc,
            # Chroma can return None for an item stored without metadata.
            "answer": (meta or {}).get("answer", "N/A"),
            "distance": dist,
        }
        for i, (doc, meta, dist) in enumerate(hits, start=1)
    ]
    return {"results": results}
# Serve the Gradio UI from the FastAPI app at the root path. The returned
# app (FastAPI routes plus the Gradio mount) replaces the module-level `app`.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")