File size: 3,579 Bytes
3a8e6a8
847b628
a9d88a0
847b628
 
 
 
 
 
 
15c714e
847b628
a9d88a0
 
 
 
 
 
 
 
847b628
15c714e
847b628
 
 
a9d88a0
15c714e
847b628
 
 
15c714e
 
3a8e6a8
15c714e
3a8e6a8
 
15c714e
3a8e6a8
 
 
15c714e
3a8e6a8
 
 
 
 
 
 
15c714e
847b628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a8e6a8
 
 
 
 
15c714e
3a8e6a8
 
 
 
 
 
 
 
15c714e
3a8e6a8
 
 
 
 
 
 
15c714e
3a8e6a8
6c121ca
 
 
 
15c714e
 
 
 
 
 
 
 
 
 
 
 
6c121ca
15c714e
6c121ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# app.py - MedQA semantic search for Hugging Face Spaces: a FastAPI JSON API plus a Gradio UI
import os
import zipfile

import chromadb
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Database setup: the Chroma store ships either as a directory or as a zip
# archive that is unpacked on first start.
DB_PATH = "./medqa_db"
ZIP_PATH = "./medqa_db.zip"


def _ensure_database(db_path=DB_PATH, zip_path=ZIP_PATH):
    """Ensure the Chroma database directory exists, unpacking it if needed.

    When *db_path* is missing and *zip_path* exists, the archive is extracted
    into the parent directory of *db_path*. The archive is assumed to contain
    the database folder at its top level (TODO confirm against the bundled
    zip).

    Raises:
        FileNotFoundError: if *db_path* still does not exist after the
            extraction attempt.
    """
    if os.path.exists(db_path):
        return
    if os.path.exists(zip_path):
        print("Extracting database from zip file...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(os.path.dirname(db_path) or ".")
        print("Database extracted successfully!")
    if not os.path.exists(db_path):
        # Fail before PersistentClient touches the path. Presumably it would
        # otherwise create an empty store there, giving a confusing
        # "collection not found" error and making the existence check above
        # skip extraction on every later start -- verify against chromadb.
        raise FileNotFoundError(
            f"Database not found at {db_path!r} and could not be extracted "
            f"from {zip_path!r}"
        )


_ensure_database()

# Load database and model
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"Collection loaded with {collection.count()} items")
print("Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("Initialization complete!")

# Search function for Gradio
def search_gradio(query, num_results=3):
    """Run a semantic search and format the hits as plain text for the UI.

    Args:
        query: Free-text medical query from the Gradio textbox.
        num_results: Number of hits to request. It is coerced with ``int()``
            because the slider value may arrive as a float.

    Returns:
        One of the following:
        - A text report with one section per hit (question text, stored
          answer and similarity).
        - A prompt string if the query is empty or missing.
        - An ``"Error: ..."`` string if encoding or querying fails. Errors are
          returned rather than raised so they appear in the output textbox.
    """
    # Reject None as well as blank text before touching the model.
    if not query or not query.strip():
        return "Please enter a query."
    try:
        embedding = model.encode(query).tolist()
        results = collection.query(
            query_embeddings=[embedding], n_results=int(num_results)
        )
        # Results are nested per query embedding; exactly one was sent.
        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        distances = results['distances'][0]

        sections = []
        for number, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1
        ):
            # Entries stored without metadata may come back as None.
            answer = (metadata or {}).get('answer', 'N/A')
            # NOTE(review): 1 - distance is a similarity only for cosine
            # distance; confirm the collection's distance space.
            sections.append(
                f"\n{'='*60}\nExample {number}\n{'='*60}\n"
                + document + "\n"
                + f"\nAnswer: {answer}\n"
                + f"Similarity: {1 - distance:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:
        return f"Error: {str(e)}"

# FastAPI app
app = FastAPI(title="MedQA Search API")

# Let browser clients on any origin call the JSON API. Credentials are
# disabled: none of the endpoints here use cookies or auth, and a wildcard
# origin combined with allow_credentials=True is disallowed for CORS (the
# middleware would have to echo each caller's Origin, effectively letting any
# site make credentialed requests).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

class SearchRequest(BaseModel):
    """Request body for ``POST /search_medqa``."""

    # Free-text medical query; it is embedded and matched against the stored
    # questions.
    query: str
    # Number of nearest questions to return. ge=1 rejects zero or negative
    # values with a 422 validation error up front. Such values are presumably
    # rejected by the vector-store query anyway, but that failure would only
    # surface as a generic 500 from the endpoint.
    num_results: int = Field(default=3, ge=1)

class SearchResponse(BaseModel):
    """Response body for ``POST /search_medqa``.

    ``results`` holds one dict per hit. As built by ``search_medqa``, each dict
    has the keys ``example_number``, ``question``, ``answer`` and ``distance``.
    """
    results: list[dict]

@app.get("/")
async def root():
    """Report that the service is up and how many items are indexed."""
    item_count = collection.count()
    return dict(
        message="MedQA Search API",
        status="running",
        collection_count=item_count,
    )

@app.post("/search_medqa", response_model=SearchResponse)
def search_medqa(request: SearchRequest):
    """Return the stored questions most similar to ``request.query``.

    Declared as a plain ``def`` rather than ``async def`` on purpose. Model
    encoding and the Chroma query are blocking calls, and FastAPI runs sync
    endpoints in a worker thread pool. Blocking calls inside ``async def``
    would stall the event loop and every other request it serves.

    Returns:
        A ``SearchResponse`` with one entry per hit. ``distance`` is the raw
        distance reported by the collection, not converted to a similarity.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            encoding or querying fails.
    """
    try:
        embedding = model.encode(request.query).tolist()
        results = collection.query(
            query_embeddings=[embedding], n_results=request.num_results
        )
        # Results are nested per query embedding; exactly one was sent.
        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        distances = results['distances'][0]

        formatted_results = [
            {
                "example_number": number,
                "question": document,
                # Entries stored without metadata may come back as None.
                "answer": (metadata or {}).get('answer', 'N/A'),
                "distance": distance,
            }
            for number, (document, metadata, distance) in enumerate(
                zip(documents, metadatas, distances), start=1
            )
        ]
        return SearchResponse(results=formatted_results)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

# Gradio interface: a query box and a result-count slider feed search_gradio,
# whose text report is shown in a single output box. Components are created
# in the same order as the inputs/outputs they fill.
_query_box = gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2)
_count_slider = gr.Slider(1, 5, value=3, step=1, label="Number of Results")
_results_box = gr.Textbox(label="Similar USMLE Questions", lines=20)

demo = gr.Interface(
    fn=search_gradio,
    inputs=[_query_box, _count_slider],
    outputs=_results_box,
    title="MedQA Search - USMLE Question Database",
    description="Search for similar USMLE Step 1 questions using semantic similarity",
    examples=[
        ["hyponatremia", 3],
        ["myocardial infarction", 2],
    ],
)

# Serve the Gradio UI from the same FastAPI app.
# NOTE(review): the JSON `root` route was registered at GET "/" before this
# mount, and routes are matched in registration order, so GET / likely
# returns that JSON instead of the Gradio page -- confirm, and consider
# moving one of them to a different path.
app = gr.mount_gradio_app(app, demo, path="/")