File size: 3,856 Bytes
847b628
0ce8c2d
 
a9d88a0
847b628
 
 
f944c35
 
847b628
98cb2a6
847b628
534ed03
98cb2a6
 
847b628
 
 
 
 
f944c35
98cb2a6
 
 
f944c35
98cb2a6
 
3a8e6a8
98cb2a6
3a8e6a8
98cb2a6
 
 
 
9ef5e84
 
 
 
249c206
 
 
 
 
9ef5e84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98cb2a6
 
3a8e6a8
98cb2a6
3a8e6a8
15c714e
98cb2a6
15c714e
98cb2a6
 
15c714e
 
534ed03
15c714e
 
6c121ca
98cb2a6
f944c35
 
 
 
 
 
 
98cb2a6
 
 
 
 
 
 
 
f944c35
0ce8c2d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
os.environ['ANONYMIZED_TELEMETRY'] = 'False'

import zipfile
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel

# Extract and load database
# The persisted Chroma directory is unpacked from the bundled archive only when
# the directory is absent and the archive is present; otherwise this is a no-op.
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
        # NOTE(review): assumes the archive's top-level entry is "medqa_db/" so
        # that extracting into "." yields DB_PATH — confirm against the zip.
        z.extractall(".")

# Module-level handles created once at import time and reused by search().
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
# Encoder used to turn query strings into embeddings for collection.query().
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')

# Search function
def search(query, num_results=3):
    """Embed *query* and look up its nearest neighbours in the collection.

    Args:
        query: Text handed to the module-level encoder.
        num_results: How many matches to request; coerced with ``int()`` so
            float-valued inputs are accepted.

    Returns:
        The raw result of ``collection.query``, unmodified.
    """
    query_vector = model.encode(query).tolist()
    top_k = int(num_results)
    return collection.query(query_embeddings=[query_vector], n_results=top_k)

# Gradio UI
def _split_question_and_choices(question_text):
    """Split a raw document into question-stem lines and answer-choice lines.

    A line whose stripped form starts with ``A``-``E`` followed by ``.`` or
    ``)`` opens a new choice. Once the first choice has been seen, every
    following non-blank line that does not open a new choice is appended to
    the previous choice (separated by one space) and blank lines are dropped.
    Lines before the first choice belong to the question stem.

    Returns:
        A ``(question_lines, choices)`` tuple of lists of strings. Choice
        entries keep the original (unstripped) first line; callers strip them
        when rendering.
    """
    # Function-scope import: the module does not import ``re`` at top level.
    import re

    # Compiled once per document instead of re-matching the pattern string
    # several times per line.
    choice_start = re.compile(r'^[A-E][\.\)]')

    question_lines = []
    choices = []
    for line in question_text.split('\n'):
        stripped = line.strip()
        if choice_start.match(stripped):
            choices.append(line)
        elif choices:
            # Continuation of a multi-line choice; blank lines are skipped.
            if stripped:
                choices[-1] += " " + stripped
        else:
            question_lines.append(line)
    return question_lines, choices


def ui_search(query, num_results=3):
    """Run a search and render the matches as plain text for the Gradio UI.

    Args:
        query: Text typed by the user. ``None``, empty, or whitespace-only
            input short-circuits with a prompt instead of searching.
        num_results: Number of matches to request from ``search``.

    Returns:
        A formatted multi-section string (one section per match), the string
        ``"Enter a query"`` for blank input, or ``"Error: <message>"`` if
        anything raises while searching or formatting.
    """
    # Guard against None as well as blank strings so the UI never sees an
    # unhandled AttributeError from ``None.strip()``.
    if not query or not query.strip():
        return "Enter a query"
    try:
        r = search(query, num_results)
        rule = "=" * 60
        parts = []
        hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
        for number, (question_text, metadata, distance) in enumerate(hits, start=1):
            parts.append(f"\n{rule}\nExample {number}\n{rule}\n")

            # DEBUG: Show raw text
            # NOTE(review): debug output is still shown to end users; kept
            # unchanged here to preserve the current UI output — consider
            # removing once parsing is confirmed.
            parts.append("DEBUG RAW TEXT:\n")
            parts.append(repr(question_text[:500]) + "\n")
            parts.append(rule + "\n\n")

            question_lines, choices = _split_question_and_choices(question_text)

            # Display question
            parts.append('\n'.join(question_lines).strip() + "\n\n")

            # Display choices if found
            if choices:
                parts.append("Answer Choices:\n")
                parts.extend(choice.strip() + "\n" for choice in choices)
                parts.append("\n")

            parts.append(f"Correct Answer: {metadata.get('answer', 'N/A')}\n")
            # NOTE(review): "1 - distance" presumes the collection reports a
            # distance where 0 means identical (e.g. cosine) — confirm.
            parts.append(f"Similarity: {1 - distance:.3f}\n")
        return "".join(parts)
    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing the app.
        return f"Error: {e}"

# Gradio front end for ui_search. The two input components are listed in the
# same order as ui_search's parameters (query, num_results).
demo = gr.Interface(
    fn=ui_search,
    inputs=[
        gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia"),
        # Slider from 1 to 5 in steps of 1, starting at 3.
        gr.Slider(1, 5, value=3, step=1, label="Results")
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    title="MedQA Search",
    # Each example row is [query, num_results], matching the inputs above.
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
)

# FastAPI
# Application instance that the /search_medqa route below is registered on.
app = FastAPI()

# Request body accepted by POST /search_medqa.
class SearchRequest(BaseModel):
    # Query text passed through to search().
    query: str
    # Number of matches to return; 3 when the field is omitted.
    num_results: int = 3

# JSON endpoint: runs search() on the request and reshapes the first result
# group into a list of {example_number, question, answer, distance} records.
# (Documented with comments rather than a docstring so the route's generated
# API description stays as it was.)
@app.post("/search_medqa")
def api_search(req: SearchRequest):
    hits = search(req.query, req.num_results)
    records = []
    for position, document in enumerate(hits['documents'][0]):
        records.append({
            "example_number": position + 1,
            "question": document,
            "answer": hits['metadatas'][0][position].get('answer', 'N/A'),
            "distance": hits['distances'][0][position],
        })
    return {"results": records}

# Attach the Gradio interface to the FastAPI app at the root path and rebind
# `app` to the value returned by the mount call.
app = gr.mount_gradio_app(app, demo, path="/")

# Launch the server
# Only when executed as a script: listen on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)