File size: 2,080 Bytes
847b628
a9d88a0
847b628
 
 
f944c35
 
847b628
98cb2a6
847b628
534ed03
98cb2a6
 
847b628
 
 
 
 
f944c35
98cb2a6
 
 
f944c35
98cb2a6
 
3a8e6a8
98cb2a6
3a8e6a8
98cb2a6
 
 
 
 
 
 
3a8e6a8
98cb2a6
3a8e6a8
15c714e
98cb2a6
15c714e
98cb2a6
 
15c714e
 
534ed03
15c714e
 
6c121ca
98cb2a6
f944c35
 
 
 
 
 
 
98cb2a6
 
 
 
 
 
 
 
f944c35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import zipfile

import chromadb
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Database bootstrap: unpack the bundled archive on first run, then open
# the persisted store and load the query encoder.
DB_PATH = "./medqa_db"
if os.path.exists("./medqa_db.zip") and not os.path.exists(DB_PATH):
    # Extraction targets the working directory; presumably the archive holds
    # a top-level medqa_db/ folder so DB_PATH exists afterwards — TODO confirm.
    with zipfile.ZipFile("./medqa_db.zip", mode="r") as z:
        z.extractall(path=".")

# Module-level handles shared by the search helpers below.
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection(name="medqa")
model = SentenceTransformer("ncbi/MedCPT-Query-Encoder")

# Search function
def search(query, num_results=3):
    """Embed *query* and look up its nearest neighbours in the collection.

    Args:
        query: Text to encode with the module-level ``model``.
        num_results: How many neighbours to request. Coerced with ``int``,
            so numeric values that are not already ints are accepted.

    Returns:
        Whatever ``collection.query`` returns for a single query embedding.
        Callers in this file read its ``documents``, ``metadatas`` and
        ``distances`` entries at index 0.
    """
    query_vector = model.encode(query).tolist()
    limit = int(num_results)
    # A one-element list is sent: exactly one query per call.
    return collection.query(query_embeddings=[query_vector], n_results=limit)

# Gradio UI
def ui_search(query, num_results=3):
    """Run a search and render the hits as one plain-text report for the UI.

    Args:
        query: Free-text query from the textbox. ``None``, empty, or
            whitespace-only input is answered with a prompt, not a search.
        num_results: Number of hits to request; forwarded to ``search``.

    Returns:
        A string with one banner-delimited section per hit (question text,
        answer, similarity), ``"Enter a query"`` for blank input, or
        ``"Error: <message>"`` if the search or the formatting fails.
    """
    # Guard against None as well as blank text so .strip() cannot raise.
    if not query or not query.strip():
        return "Enter a query"
    try:
        r = search(query, num_results)
        banner = "=" * 60
        # One query is sent per call, so index 0 holds the hits for it.
        hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
        sections = []
        for rank, (doc, meta, dist) in enumerate(hits, start=1):
            # NOTE(review): a hit may come back with no metadata (None);
            # fall back to an empty mapping so the default answer is used.
            answer = (meta or {}).get('answer', 'N/A')
            # NOTE(review): "1 - distance" is only a similarity score if the
            # collection uses a cosine-style distance — TODO confirm.
            sections.append(
                f"\n{banner}\nExample {rank}\n{banner}\n"
                f"{doc}\n\nAnswer: {answer}\n"
                f"Similarity: {1 - dist:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"Error: {e}"

# Gradio front end: a query box plus a result-count slider feeding
# ui_search, with the rendered report shown in a read-out textbox.
demo = gr.Interface(
    title="MedQA Search",
    fn=ui_search,
    inputs=[
        gr.Textbox(placeholder="e.g., hyponatremia", label="Medical Query"),
        gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Results"),
    ],
    outputs=gr.Textbox(lines=20, label="Similar USMLE Questions"),
    # Each example row is [query, num_results], matching the inputs order.
    examples=[
        ["hyponatremia", 3],
        ["myocardial infarction", 2],
    ],
)

# FastAPI application that hosts the JSON search route (/search_medqa).
app = FastAPI()

class SearchRequest(BaseModel):
    """Request body for ``POST /search_medqa``."""

    # Free-text query to embed and search for.
    query: str
    # Number of hits to return. Constrained to >= 1 so zero or negative
    # counts are rejected during request validation instead of being passed
    # on to the vector-store query.
    num_results: int = Field(default=3, ge=1)

@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """Return the nearest stored questions for ``req.query`` as JSON.

    Args:
        req: Validated request body carrying the query text and hit count.

    Returns:
        ``{"results": [...]}`` where each item has ``example_number``
        (1-based rank), ``question``, ``answer`` (``'N/A'`` when the hit's
        metadata has no answer) and the raw ``distance`` from the query.
    """
    r = search(req.query, req.num_results)
    # One query is sent per call, so index 0 holds the hits for it.
    hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
    return {"results": [
        {
            "example_number": rank,
            "question": doc,
            # NOTE(review): a hit may come back with no metadata (None);
            # fall back to an empty mapping so the default answer is used.
            "answer": (meta or {}).get('answer', 'N/A'),
            "distance": dist,
        }
        for rank, (doc, meta, dist) in enumerate(hits, start=1)
    ]}

# Attach the Gradio UI to the FastAPI app at the site root. The
# /search_medqa route above is registered before this mount is added.
app = gr.mount_gradio_app(app, demo, path="/")