Spaces:
Sleeping
Sleeping
File size: 2,080 Bytes
847b628 a9d88a0 847b628 f944c35 847b628 98cb2a6 847b628 534ed03 98cb2a6 847b628 f944c35 98cb2a6 f944c35 98cb2a6 3a8e6a8 98cb2a6 3a8e6a8 98cb2a6 3a8e6a8 98cb2a6 3a8e6a8 15c714e 98cb2a6 15c714e 98cb2a6 15c714e 534ed03 15c714e 6c121ca 98cb2a6 f944c35 98cb2a6 f944c35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import os
import zipfile

import chromadb
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# One-time startup: make sure the Chroma database directory exists (unpacking the
# bundled archive if only the archive is present), open the "medqa" collection,
# and load the query-side sentence encoder used by search().
DB_PATH = "./medqa_db"


def _unpack_db_archive(archive="./medqa_db.zip"):
    """Extract *archive* into the working directory if DB_PATH is missing.

    Does nothing when DB_PATH already exists or when *archive* is absent.
    NOTE(review): assumes the archive's top-level entry is the ``medqa_db``
    directory, so that extracting into "." recreates DB_PATH — confirm.
    """
    if os.path.exists(DB_PATH):
        return
    if not os.path.exists(archive):
        return
    with zipfile.ZipFile(archive, 'r') as bundle:
        bundle.extractall(".")


_unpack_db_archive()
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
def search(query, num_results=3):
    """Return the Chroma query result for the entries nearest to *query*.

    The text is embedded with the module-level ``model`` and looked up in the
    module-level ``collection``. ``num_results`` is passed through ``int()``
    (presumably because UI widgets may deliver a float) and becomes
    ``n_results``. The raw result dict from ``collection.query`` is returned;
    callers index it as ``r['documents'][0]``, ``r['metadatas'][0]`` and
    ``r['distances'][0]`` because exactly one query embedding is sent.
    """
    query_vector = model.encode(query).tolist()
    k = int(num_results)
    return collection.query(query_embeddings=[query_vector], n_results=k)
def ui_search(query, num_results=3):
    """Render the nearest MedQA examples for *query* as plain text for Gradio.

    Returns ``"Enter a query"`` for a missing or blank query, and
    ``"Error: <message>"`` if the lookup or formatting fails, so the UI always
    receives a string instead of an exception.
    """
    # Guard None as well as blank text; .strip() on None would raise outside the try.
    if not query or not query.strip():
        return "Enter a query"
    rule = "=" * 60
    try:
        r = search(query, num_results)
        hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
        parts = []
        for rank, (doc, meta, dist) in enumerate(hits, start=1):
            # Chroma may return None for a record stored without metadata.
            answer = (meta or {}).get('answer', 'N/A')
            parts.append(f"\n{rule}\nExample {rank}\n{rule}\n")
            parts.append(doc + f"\n\nAnswer: {answer}\n")
            # NOTE(review): 1 - distance is only a similarity score if the
            # collection uses cosine distance; the metric is not visible here.
            parts.append(f"Similarity: {1 - dist:.3f}\n")
        return "".join(parts)
    except Exception as e:  # UI boundary: surface any failure as text
        return f"Error: {e}"
# Gradio front end: a free-text query box plus a 1-5 result-count slider,
# both fed to ui_search, whose text output fills a 20-line textbox.
_query_box = gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia")
_count_slider = gr.Slider(1, 5, value=3, step=1, label="Results")
demo = gr.Interface(
    fn=ui_search,
    inputs=[_query_box, _count_slider],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    title="MedQA Search",
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]],
)
# FastAPI application exposing the same search as a JSON endpoint.
app = FastAPI()


class SearchRequest(BaseModel):
    """JSON request body for ``POST /search_medqa``."""

    # Free-text medical query to embed and search for.
    query: str
    # Number of nearest examples to return (default 3). Constrained to >= 1 so an
    # invalid count is rejected at validation time (HTTP 422) instead of being
    # forwarded to collection.query(), which does not accept a non-positive
    # n_results and would surface as a server error.
    num_results: int = Field(3, ge=1)
@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """Return the nearest MedQA examples for ``req.query`` as JSON.

    Response shape: ``{"results": [{"example_number", "question", "answer",
    "distance"}, ...]}``. ``example_number`` is the 1-based rank, ``answer``
    is the record's ``answer`` metadata (``"N/A"`` when absent), and
    ``distance`` is the raw value returned by ``collection.query``.
    """
    r = search(req.query, req.num_results)
    hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
    return {"results": [
        {
            "example_number": rank,
            "question": doc,
            # Chroma may return None for a record stored without metadata.
            "answer": (meta or {}).get('answer', 'N/A'),
            "distance": dist,
        }
        for rank, (doc, meta, dist) in enumerate(hits, start=1)
    ]}
# Serve the Gradio UI at "/" on the same FastAPI app that exposes /search_medqa.
# (The trailing " |" in the scraped source was page residue and a syntax error.)
app = gr.mount_gradio_app(app, demo, path="/")