# Page header captured from the web file viewer (not part of the program):
# imurra's picture
# Update app.py
# f944c35 verified
# raw
# history blame
# 2.78 kB
import asyncio
import os
import zipfile

import chromadb
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# Unpack the bundled database archive when the database directory is missing.
DB_PATH = "./medqa_db"
if os.path.exists("./medqa_db.zip") and not os.path.exists(DB_PATH):
    print("Extracting database...")
    # ZipFile opens read-only by default; members are written into the
    # current working directory.
    with zipfile.ZipFile("./medqa_db.zip") as zip_ref:
        zip_ref.extractall(path=".")
    print("Extracted!")
# Open the persisted Chroma collection and load the query encoder once at
# import time; both are shared by every search request.
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection(name="medqa")
print(f"Loaded {collection.count()} items")

print("Loading model...")
# NOTE(review): queries are embedded with this encoder, so the stored vectors
# presumably come from a compatible document encoder - confirm against the
# script that built the collection.
model = SentenceTransformer("ncbi/MedCPT-Query-Encoder")
print("Ready!")
# Shared search helper used by both the Gradio UI and the JSON endpoint.
def search_function(query, num_results=3):
    """Embed *query* and look up its nearest neighbours in the collection.

    Args:
        query: Text to encode with the module-level ``model``.
        num_results: Number of matches to request; coerced with ``int()``
            before being passed to the collection (presumably because UI
            slider values may not arrive as ints - confirm).

    Returns:
        Whatever ``collection.query`` returns. Callers in this file read its
        ``'documents'``, ``'metadatas'`` and ``'distances'`` entries.
    """
    query_vector = model.encode(query).tolist()
    return collection.query(
        query_embeddings=[query_vector],
        n_results=int(num_results),
    )
# Gradio interface
def search_gradio(query, num_results=3):
    """Run a search and render the matches as plain text for the Gradio UI.

    Args:
        query: Text typed by the user. ``None`` or whitespace-only input is
            treated as "no query".
        num_results: Number of matches to request from ``search_function``.

    Returns:
        A text report with one section per match, a prompt asking for a
        query when none was given, or ``"Error: ..."`` if the search raised.
    """
    # Check for None before calling .strip(); this guard sits outside the
    # try block, so an AttributeError here would not be turned into a message.
    if not query or not query.strip():
        return "Please enter a search query."
    try:
        results = search_function(query, num_results)
        # Only one query is sent, so the hits are the first (only) inner list.
        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        distances = results['distances'][0]

        divider = '=' * 60
        sections = []
        for rank, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1
        ):
            # A hit stored without metadata may come back as None.
            answer = (metadata or {}).get('answer', 'N/A')
            # NOTE(review): `1 - distance` is a similarity score only if the
            # collection uses cosine distance - confirm how it was created.
            sections.append(
                f"\n{divider}\nExample {rank}\n{divider}\n"
                f"{document}\n"
                f"\nAnswer: {answer}\n"
                f"Similarity: {1 - distance:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"Error: {str(e)}"
# Gradio UI definition: a query box and a result-count slider feed
# search_gradio, whose text report is shown in a read-out textbox.
demo = gr.Interface(
    fn=search_gradio,
    inputs=[
        gr.Textbox(lines=2, label="Medical Query", placeholder="e.g., hyponatremia"),
        gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Number of Results"),
    ],
    outputs=gr.Textbox(lines=20, label="Similar USMLE Questions"),
    title="MedQA Search",
    description="Search for similar USMLE Step 1 questions",
    # Each example row is [query, number of results].
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]],
)
# FastAPI for ChatGPT
# Application object that carries the JSON search route declared below; the
# Gradio UI is mounted onto this same object at the end of the file.
app = FastAPI()
class SearchRequest(BaseModel):
    """Request body accepted by the ``/search_medqa`` endpoint."""

    # Free-text query to embed and match against the collection.
    query: str
    # Number of matches to return. Constrained to be at least 1 so that zero
    # or negative values are rejected during request validation instead of
    # being forwarded to the collection query.
    num_results: int = Field(default=3, ge=1)
class SearchResponse(BaseModel):
    """Response body returned by the ``/search_medqa`` endpoint."""

    # One dict per match, built by api_search with the keys
    # "example_number", "question", "answer" and "distance".
    results: list[dict]
@app.post("/search_medqa", response_model=SearchResponse)
async def api_search(request: SearchRequest):
    """Handle ``POST /search_medqa`` and return the matches as JSON.

    Args:
        request: Validated request body carrying the query text and the
            number of results wanted.

    Returns:
        A ``SearchResponse`` whose ``results`` list holds one dict per match
        with the keys ``example_number`` (1-based), ``question``, ``answer``
        (``'N/A'`` when the metadata has none) and ``distance``.
    """
    # search_function encodes and queries synchronously. Run it in a worker
    # thread so this coroutine does not block the event loop, which also
    # serves the Gradio UI mounted on the same app.
    results = await asyncio.to_thread(
        search_function, request.query, request.num_results
    )
    # Only one query is sent, so the hits are the first (only) inner list.
    documents = results['documents'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    formatted = [
        {
            "example_number": rank,
            "question": document,
            # A hit stored without metadata may come back as None.
            "answer": (metadata or {}).get('answer', 'N/A'),
            "distance": distance,
        }
        for rank, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1
        )
    ]
    return SearchResponse(results=formatted)
# Mount Gradio on FastAPI
# Attach the Gradio UI at the root path of the FastAPI application. `app` is
# rebound to the value returned by gr.mount_gradio_app. The /search_medqa
# route was registered on `app` before this call.
app = gr.mount_gradio_app(app, demo, path="/")