# MedQA search Space: Gradio demo + FastAPI endpoint over a Chroma vector DB.
import os
import zipfile
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
# Extract database
# One-time setup at import: unzip the bundled Chroma database if it is not
# already extracted, then load the collection and the query encoder into
# module-level globals used by the search functions below.
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    print("Extracting database...")
    with zipfile.ZipFile("./medqa_db.zip", 'r') as zip_ref:
        zip_ref.extractall(".")  # zip is expected to contain the medqa_db/ dir
    print("Extracted!")
# Load database and model
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"Loaded {collection.count()} items")
print("Loading model...")
# NOTE(review): assumes the "medqa" collection was embedded with the same
# MedCPT encoder family so query/document vectors are comparable — confirm.
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("Ready!")
# Search function
def search_function(query, num_results=3):
    """Embed *query* with the MedCPT encoder and run a nearest-neighbor
    search against the MedQA collection.

    Returns the raw Chroma query result dict (documents, metadatas,
    distances, each wrapped in a one-element outer list).
    """
    query_vector = model.encode(query).tolist()
    return collection.query(
        query_embeddings=[query_vector],
        n_results=int(num_results),  # sliders may hand us a float
    )
# Gradio interface
def search_gradio(query, num_results=3):
    """Run a search and render the results as plain text for the Gradio UI.

    Args:
        query: free-text medical query; blank input short-circuits with a
            prompt message instead of hitting the model.
        num_results: number of matches to request (slider value).

    Returns:
        A formatted multi-section string, or an "Error: ..." message so the
        UI never shows a traceback.
    """
    if not query.strip():
        return "Please enter a search query."
    try:
        results = search_function(query, num_results)
        docs = results['documents'][0]
        metas = results['metadatas'][0]
        dists = results['distances'][0]
        parts = []
        # zip the three parallel result lists instead of indexing by range();
        # join once at the end to avoid quadratic string concatenation.
        for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists), start=1):
            parts.append(f"\n{'='*60}\nExample {i}\n{'='*60}\n")
            parts.append(doc + "\n")
            parts.append(f"\nAnswer: {meta.get('answer', 'N/A')}\n")
            # Chroma returns a distance; present it as similarity = 1 - distance.
            parts.append(f"Similarity: {1 - dist:.3f}\n")
        return "".join(parts)
    except Exception as e:
        # Deliberate best-effort boundary: surface the error text in the UI.
        return f"Error: {str(e)}"
# Gradio demo: free-text query + result-count slider -> formatted text output.
_query_box = gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2)
_count_slider = gr.Slider(1, 5, value=3, step=1, label="Number of Results")
_results_box = gr.Textbox(label="Similar USMLE Questions", lines=20)
demo = gr.Interface(
    fn=search_gradio,
    inputs=[_query_box, _count_slider],
    outputs=_results_box,
    title="MedQA Search",
    description="Search for similar USMLE Step 1 questions",
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]],
)
# FastAPI for ChatGPT
app = FastAPI()
class SearchRequest(BaseModel):
    # Request body for POST /search_medqa.
    query: str
    num_results: int = 3  # top-k matches to return (UI slider offers 1-5)
class SearchResponse(BaseModel):
    # Response body: one dict per match (example_number/question/answer/distance).
    results: list[dict]
@app.post("/search_medqa")
async def api_search(request: SearchRequest):
    """JSON API mirror of the Gradio search.

    Runs the same vector search as the UI and returns structured matches.
    Note: this endpoint reports the raw Chroma `distance`; the Gradio UI
    displays `1 - distance` as a similarity score.
    """
    results = search_function(request.query, request.num_results)
    docs = results['documents'][0]
    metas = results['metadatas'][0]
    dists = results['distances'][0]
    # Build one record per match from the three parallel result lists.
    formatted = [
        {
            "example_number": i,
            "question": doc,
            "answer": meta.get('answer', 'N/A'),
            "distance": dist,
        }
        for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists), start=1)
    ]
    return SearchResponse(results=formatted)
# Mount Gradio on FastAPI
# Fix: the original line ended with a stray trailing "|" (a SyntaxError,
# almost certainly a copy/scrape artifact). Mount the demo at the root path
# so one server exposes both the UI and the /search_medqa API.
app = gr.mount_gradio_app(app, demo, path="/")