File size: 2,784 Bytes
847b628
a9d88a0
847b628
 
 
f944c35
 
847b628
534ed03
847b628
534ed03
 
 
a9d88a0
534ed03
847b628
15c714e
847b628
 
 
534ed03
 
847b628
534ed03
847b628
f944c35
 
 
 
 
 
 
 
3a8e6a8
534ed03
 
3a8e6a8
f944c35
3a8e6a8
 
15c714e
3a8e6a8
 
 
 
 
 
 
15c714e
f944c35
15c714e
 
 
 
 
534ed03
 
15c714e
 
6c121ca
f944c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import zipfile
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel

# Extract database
# One-time setup: unpack the prebuilt Chroma database if only the zip is present.
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    print("Extracting database...")
    # NOTE(review): extractall is path-traversal prone on untrusted archives;
    # acceptable here only because the zip ships with the app — confirm.
    with zipfile.ZipFile("./medqa_db.zip", 'r') as zip_ref:
        zip_ref.extractall(".")
    print("Extracted!")

# Load database and model
# Module-level load so the collection and encoder are ready before any request
# is served; order matters (extraction above must happen first).
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"Loaded {collection.count()} items")
print("Loading model...")
# NOTE(review): query encoder should match the encoder used to build the
# index (MedCPT family) — confirm against the DB build script.
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("Ready!")

# Search function
def search_function(query, num_results=3):
    """Embed *query* and retrieve its nearest MedQA entries from Chroma.

    Args:
        query: Free-text search string.
        num_results: Number of neighbours to request (coerced to int).

    Returns:
        The raw Chroma query-result dict (documents/metadatas/distances lists).
    """
    query_vector = model.encode(query).tolist()
    hits = collection.query(
        query_embeddings=[query_vector],
        n_results=int(num_results),
    )
    return hits

# Gradio interface
def search_gradio(query, num_results=3):
    """Format MedQA search results as readable text for the Gradio UI.

    Args:
        query: Free-text medical query; blank/whitespace input is rejected
            with a help message instead of a search.
        num_results: How many similar questions to retrieve.

    Returns:
        A formatted multi-line string of results, or a help/error message.
    """
    if not query.strip():
        return "Please enter a search query."

    try:
        results = search_function(query, num_results)
        # Build pieces in a list and join once — avoids quadratic `+=` growth.
        parts = []
        docs = results['documents'][0]
        metas = results['metadatas'][0]
        dists = results['distances'][0]
        for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists), start=1):
            parts.append(f"\n{'='*60}\nExample {i}\n{'='*60}\n")
            parts.append(doc + "\n")
            parts.append(f"\nAnswer: {meta.get('answer', 'N/A')}\n")
            # NOTE(review): assumes the collection's distance metric makes
            # `1 - distance` a meaningful similarity (e.g. cosine) — confirm.
            parts.append(f"Similarity: {1 - dist:.3f}\n")
        return "".join(parts)
    except Exception as e:
        # Surface any failure as text in the UI rather than crashing the app.
        return f"Error: {str(e)}"

# Declarative Gradio UI: textbox + result-count slider feeding search_gradio,
# with a plain-text output pane.
demo = gr.Interface(
    fn=search_gradio,
    inputs=[
        gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
        gr.Slider(1, 5, value=3, step=1, label="Number of Results")
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    title="MedQA Search",
    description="Search for similar USMLE Step 1 questions",
    # Clickable example inputs shown beneath the form.
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
)

# FastAPI for ChatGPT
# JSON API alongside the Gradio UI; the Gradio app is mounted onto this
# same ASGI application further down.
app = FastAPI()

class SearchRequest(BaseModel):
    """Request body for POST /search_medqa."""
    query: str  # free-text medical query to embed and search
    num_results: int = 3  # how many similar questions to return

class SearchResponse(BaseModel):
    """Response body: one dict per retrieved question (example_number,
    question, answer, distance — see api_search)."""
    results: list[dict]

@app.post("/search_medqa")
async def api_search(request: SearchRequest):
    """Search MedQA and return structured JSON results.

    Args:
        request: Validated body carrying the query text and result count.

    Returns:
        SearchResponse whose ``results`` list pairs each retrieved question
        with its stored answer and raw vector distance (lower = closer).
    """
    results = search_function(request.query, request.num_results)
    # Walk the three parallel result lists together instead of indexing
    # by range(len(...)).
    formatted = [
        {
            "example_number": i,
            "question": doc,
            "answer": meta.get('answer', 'N/A'),
            "distance": dist,
        }
        for i, (doc, meta, dist) in enumerate(
            zip(results['documents'][0],
                results['metadatas'][0],
                results['distances'][0]),
            start=1,
        )
    ]
    return SearchResponse(results=formatted)

# Mount Gradio on FastAPI
# Serve the Gradio UI at "/" from the same ASGI app that exposes /search_medqa.
app = gr.mount_gradio_app(app, demo, path="/")