File size: 2,415 Bytes
847b628
0ce8c2d
 
a9d88a0
847b628
 
 
f944c35
 
847b628
98cb2a6
847b628
534ed03
98cb2a6
 
847b628
 
 
 
 
f944c35
98cb2a6
 
 
f944c35
98cb2a6
 
3a8e6a8
98cb2a6
3a8e6a8
98cb2a6
85baae8
 
 
 
 
98cb2a6
 
 
 
 
 
3a8e6a8
98cb2a6
3a8e6a8
15c714e
98cb2a6
15c714e
98cb2a6
 
15c714e
 
534ed03
15c714e
 
6c121ca
98cb2a6
f944c35
 
 
 
 
 
 
98cb2a6
 
 
 
 
 
 
 
f944c35
0ce8c2d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
os.environ['ANONYMIZED_TELEMETRY'] = 'False'

import zipfile
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel

# Database bootstrap: unpack the bundled archive on first start, then open
# the persisted collection and load the query encoder.
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH):
    # Only extract when the database directory is absent and the zip archive
    # is present next to the app.
    if os.path.exists("./medqa_db.zip"):
        with zipfile.ZipFile("./medqa_db.zip", mode='r') as archive:
            archive.extractall(path=".")

# Open the on-disk Chroma store and the "medqa" collection inside it.
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")

# Encoder used to embed incoming queries.
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')

# Search function
def search(query, num_results=3):
    """Embed ``query`` and return the nearest entries from ``collection``.

    Args:
        query: Text passed to ``model.encode``.
        num_results: How many neighbours to request; coerced with ``int()``.

    Returns:
        Whatever ``collection.query`` returns for the single query embedding.
    """
    query_vector = model.encode(query).tolist()
    top_k = int(num_results)
    return collection.query(query_embeddings=[query_vector], n_results=top_k)

# Gradio UI
def ui_search(query, num_results=3):
    """Run a search for the Gradio form and render the hits as plain text.

    Args:
        query: Text typed by the user. ``None``, empty, or whitespace-only
            input is rejected without running a search.
        num_results: Maximum number of hits to request from ``search``.

    Returns:
        One formatted section per hit (question text, answer, similarity);
        ``"Enter a query"`` for blank input; ``"No results found"`` when the
        search returns no hits; or ``"Error: <message>"`` if anything raised
        while searching or formatting.
    """
    # Guard against None as well as blank strings so the check itself cannot
    # raise before the error boundary below is entered.
    if not query or not query.strip():
        return "Enter a query"
    try:
        r = search(query, num_results)
        # A single query is sent, so index 0 selects its row in each field.
        documents = r['documents'][0]
        metadatas = r['metadatas'][0]
        distances = r['distances'][0]

        # Previously a debug print indexed the first hit unconditionally,
        # which turned an empty result set into an IndexError message.
        if not documents:
            return "No results found"

        rule = '=' * 60
        sections = []
        for number, (doc, meta, dist) in enumerate(
            zip(documents, metadatas, distances), start=1
        ):
            # NOTE(review): "1 - distance" is only a true similarity if the
            # collection uses a cosine-style distance -- confirm.
            sections.append(
                f"\n{rule}\nExample {number}\n{rule}\n"
                f"{doc}\n\nAnswer: {meta.get('answer', 'N/A')}\n"
                f"Similarity: {1 - dist:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:
        # UI boundary: report the failure in the output box instead of
        # letting the exception propagate to the web framework.
        return f"Error: {e}"

# Widgets are created in the same order the interface receives them:
# the two inputs first, then the output box.
query_box = gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia")
count_slider = gr.Slider(1, 5, value=3, step=1, label="Results")
results_box = gr.Textbox(label="Similar USMLE Questions", lines=20)

# Web form that passes (query, result count) to ui_search and shows the text
# it returns.
demo = gr.Interface(
    fn=ui_search,
    inputs=[query_box, count_slider],
    outputs=results_box,
    title="MedQA Search",
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]],
)

# FastAPI application that hosts the JSON search endpoint defined below.
app = FastAPI()

# Request body for POST /search_medqa.
class SearchRequest(BaseModel):
    # Free-text query forwarded to search().
    query: str
    # Number of results to request; defaults to 3 when omitted.
    num_results: int = 3

# JSON endpoint: runs search() for the posted query and returns one entry per
# hit with its 1-based position, question text, answer ('N/A' when the
# metadata has no 'answer' key) and raw distance.
@app.post("/search_medqa")
def api_search(req: SearchRequest):
    hits = search(req.query, req.num_results)

    results = []
    # search() sends a single query embedding, so index 0 selects its row.
    for position, question in enumerate(hits['documents'][0]):
        results.append({
            "example_number": position + 1,
            "question": question,
            "answer": hits['metadatas'][0][position].get('answer', 'N/A'),
            "distance": hits['distances'][0][position],
        })
    return {"results": results}

# Serve the Gradio interface at the root path of the FastAPI app; the value
# returned by the call replaces `app`.
app = gr.mount_gradio_app(app, demo, path="/")

# Launch the server
if __name__ == "__main__":
    # Imported here, so uvicorn is only needed when run as a script.
    import uvicorn
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)