imurra's picture
Syntax error in the file! There's a typo on line 148
6c121ca verified
raw
history blame
4.87 kB
# app.py - Hugging Face Spaces version - FIXED
import os
import zipfile
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
# Where the persisted Chroma database lives, and where its zipped copy is kept.
DB_PATH = "./medqa_db"
ZIP_PATH = "./medqa_db.zip"

# Unpack the archive into the working directory, but only when the zip file is
# present and the database directory does not exist yet.
if os.path.exists(ZIP_PATH) and not os.path.exists(DB_PATH):
    print("Extracting database from zip file...")
    with zipfile.ZipFile(ZIP_PATH, mode="r") as archive:
        archive.extractall(path=".")
    print("Database extracted successfully!")

# Open the on-disk database and look up the "medqa" collection. These
# module-level objects are shared by every request handler below.
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"Collection loaded with {collection.count()} items")

# Load the encoder used to embed incoming queries.
# NOTE(review): confirm the vectors stored in the collection were produced
# with a compatible encoder/pooling setup, otherwise distances are not
# meaningful.
print("Loading MedCPT model...")
model = SentenceTransformer("ncbi/MedCPT-Query-Encoder")
print("Initialization complete!")
# Gradio interface function
def search_interface(query: str, num_results: int = 3):
    """Run a semantic search and format the hits as plain text for the UI.

    Args:
        query: Free-text topic or clinical scenario to search for.
        num_results: Number of matches to request; coerced with ``int()``
            before it is passed to the collection query.

    Returns:
        A text report with one section per hit (document, answer and
        ``1 - distance`` as the similarity), or a short message when the
        query is blank, nothing matched, or the lookup raised an error.
    """
    # Guard against None as well as empty / whitespace-only input.
    if not query or not query.strip():
        return "Please enter a search query."

    try:
        embedding = model.encode(query).tolist()
        results = collection.query(
            query_embeddings=[embedding],
            n_results=int(num_results),
        )

        # Only one query embedding is sent, so the first batch holds the hits.
        documents = results['documents'][0]
        if not documents:
            return "No results found."

        metadatas = results['metadatas'][0]
        distances = results['distances'][0]
        rule = '=' * 60

        sections = []
        for rank, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1
        ):
            # A hit may come back without metadata; treat that as "no answer".
            answer = (metadata or {}).get('answer', 'N/A')
            sections.append(
                f"\n{rule}\n"
                f"Example {rank}\n"
                f"{rule}\n"
                f"{document}\n"
                f"\nAnswer: {answer}\n"
                f"Similarity: {1 - distance:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising.
        return f"Error: {str(e)}"
# Gradio UI: inputs in the left column, results in the right column, and a
# few canned example queries.
with gr.Blocks(title="MedQA Search") as demo:
    gr.Markdown("# MedQA Search - USMLE Question Database")
    gr.Markdown("Search for similar USMLE Step 1 questions using semantic similarity")

    with gr.Row():
        # Left column: query text, number of results, and the search button.
        with gr.Column():
            query_input = gr.Textbox(
                lines=2,
                placeholder="e.g., hyponatremia",
                label="Medical Topic or Clinical Scenario",
            )
            num_results_slider = gr.Slider(
                label="Number of Examples",
                minimum=1,
                maximum=5,
                step=1,
                value=3,
            )
            search_btn = gr.Button("Search", variant="primary")

        # Right column: formatted search output.
        with gr.Column():
            output_text = gr.Textbox(
                label="Similar USMLE Questions",
                max_lines=50,
                lines=25,
            )

    # Clicking the button runs the search and writes the text into the output.
    search_btn.click(
        fn=search_interface,
        inputs=[query_input, num_results_slider],
        outputs=output_text,
    )

    # Example rows pre-fill the query box and the slider.
    gr.Examples(
        inputs=[query_input, num_results_slider],
        examples=[
            ["hyponatremia", 3],
            ["myocardial infarction", 2],
            ["diabetic ketoacidosis", 3],
        ],
    )
# FastAPI application that carries the JSON API endpoints.
app = FastAPI(title="MedQA Search API")

# CORS: the API is open to any origin, method and header. Credentialed
# cross-origin requests are left disabled because a wildcard origin must not
# be combined with allow_credentials=True, and none of the endpoints visible
# in this file read cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body accepted by the POST /search_medqa endpoint.
class SearchRequest(BaseModel):
    # Free-text query; it is embedded and matched against the collection.
    query: str
    # Number of matches to return; 3 when the client omits the field.
    num_results: int = 3
# Response body returned by the POST /search_medqa endpoint.
class SearchResponse(BaseModel):
    # One dict per hit; the handler fills in "example_number", "question",
    # "answer" and "distance".
    results: list[dict]
@app.get("/")
async def root():
    # Service banner: a fixed message and status plus the current item count
    # of the collection.
    # NOTE(review): Gradio is mounted on the same "/" path at the end of the
    # file; this route is registered first, so it presumably answers GET /
    # instead of the UI page. Confirm which one is intended.
    item_count = collection.count()
    return dict(
        message="MedQA Search API - Hugging Face Version",
        status="running",
        collection_count=item_count,
    )
@app.post("/search_medqa", response_model=SearchResponse)
async def search_medqa(request: SearchRequest):
    """Search MedQA database for similar USMLE questions"""
    # Parameters: request.query is the text to embed; request.num_results is
    # the number of matches to return.
    # Returns: SearchResponse whose results are dicts with the keys
    # "example_number", "question", "answer" and "distance".
    # Raises: HTTPException 422 for a non-positive num_results, 500 for any
    # failure while embedding or querying.
    # NOTE(review): model.encode and collection.query are called synchronously
    # inside this async handler, so they run on the event loop.

    # Validate outside the try block so this HTTPException is not caught below
    # and re-wrapped as a 500.
    if request.num_results < 1:
        raise HTTPException(
            status_code=422,
            detail="num_results must be a positive integer",
        )

    try:
        embedding = model.encode(request.query).tolist()
        results = collection.query(
            query_embeddings=[embedding],
            n_results=request.num_results,
        )

        # Only one query embedding is sent, so the first batch holds the hits.
        documents = results['documents'][0]

        # Metadata and distances may be absent or None; fall back to
        # placeholders instead of indexing into a missing value.
        metadata_batches = results.get('metadatas')
        distance_batches = results.get('distances')
        placeholders = [None] * len(documents)
        metadatas = metadata_batches[0] if metadata_batches else placeholders
        distances = distance_batches[0] if distance_batches else placeholders

        formatted_results = [
            {
                "example_number": rank,
                "question": document,
                # A hit may have no metadata; report 'N/A' in that case.
                "answer": (metadata or {}).get('answer', 'N/A'),
                "distance": distance,
            }
            for rank, (document, metadata, distance) in enumerate(
                zip(documents, metadatas, distances), start=1
            )
        ]
        return SearchResponse(results=formatted_results)
    except Exception as e:
        # API boundary: report the failure as a 500 and keep the cause chained.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health():
    # Health probe: a fixed "healthy" status plus the collection's item count.
    payload = {"status": "healthy"}
    payload["items"] = collection.count()
    return payload
# Mount the Gradio Blocks UI (`demo`) on the FastAPI app at path "/" and rebind
# `app` to the value returned by the mount call.
# NOTE(review): a GET "/" route is already registered on `app` earlier in the
# file; since it was added first it presumably takes precedence over the UI
# index page at the same path. Confirm which one should answer "/".
app = gr.mount_gradio_app(app, demo, path="/")