Spaces:
Sleeping
Sleeping
File size: 3,579 Bytes
3a8e6a8 847b628 a9d88a0 847b628 15c714e 847b628 a9d88a0 847b628 15c714e 847b628 a9d88a0 15c714e 847b628 15c714e 3a8e6a8 15c714e 3a8e6a8 15c714e 3a8e6a8 15c714e 3a8e6a8 15c714e 847b628 3a8e6a8 15c714e 3a8e6a8 15c714e 3a8e6a8 15c714e 3a8e6a8 6c121ca 15c714e 6c121ca 15c714e 6c121ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
# app.py - Hugging Face Spaces version - FIXED
import os
import zipfile

import chromadb
import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# Database setup: location of the Chroma store and of the zip it may ship in.
DB_PATH = "./medqa_db"
ZIP_PATH = "./medqa_db.zip"

# Extract database if needed: only when the database path is absent and the
# archive is present; in every other case the filesystem is left untouched.
if not os.path.exists(DB_PATH):
    if os.path.exists(ZIP_PATH):
        print("Extracting database from zip file...")
        with zipfile.ZipFile(ZIP_PATH, mode="r") as zip_ref:
            zip_ref.extractall(".")
        print("Database extracted successfully!")

# Load database and model. Both are created once at import time and shared
# by the Gradio callback and the FastAPI handlers defined later in the file.
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"Collection loaded with {collection.count()} items")

print("Loading MedCPT model...")
model = SentenceTransformer("ncbi/MedCPT-Query-Encoder")
print("Initialization complete!")
# Search function for Gradio
def search_gradio(query, num_results=3):
    """Run a semantic search and render the hits as plain text.

    Args:
        query: Free-text query. ``None``, empty and whitespace-only values
            are answered with a prompt instead of being embedded.
        num_results: Number of hits to request; coerced with ``int()``
            before being passed to the collection query.

    Returns:
        One formatted section per hit, the literal
        ``"Please enter a query."`` for a blank query, or an
        ``"Error: ..."`` string when embedding or querying raises.
    """
    # Check for None before calling .strip(): this guard sits outside the
    # try block, so an AttributeError here would escape unhandled.
    if not query or not query.strip():
        return "Please enter a query."
    try:
        embedding = model.encode(query).tolist()
        results = collection.query(query_embeddings=[embedding], n_results=int(num_results))
        # Only one embedding is sent, so index 0 selects its result lists.
        hits = zip(
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0],
        )
        divider = '=' * 60
        sections = []
        for rank, (document, metadata, distance) in enumerate(hits, start=1):
            # NOTE(review): "1 - distance" is shown as a similarity; this
            # presumably assumes a cosine-style distance — confirm against
            # the collection's configured metric.
            sections.append(
                f"\n{divider}\nExample {rank}\n{divider}\n"
                + document + "\n"
                + f"\nAnswer: {metadata.get('answer', 'N/A')}\n"
                + f"Similarity: {1 - distance:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:
        # Failures are reported as text in the return value rather than raised.
        return f"Error: {str(e)}"
# FastAPI app
app = FastAPI(title="MedQA Search API")

# CORS is fully permissive: any origin, method and header is accepted and
# credentials are enabled.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True — confirm this is intended.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class SearchRequest(BaseModel):
    """Request body accepted by the POST /search_medqa endpoint."""

    # Free-text query that is embedded and matched against the collection.
    query: str
    # Number of nearest matches to return. Bounded below at 1 so that zero
    # or negative values are rejected during request validation instead of
    # failing later inside the vector-store query.
    num_results: int = Field(default=3, ge=1)
class SearchResponse(BaseModel):
    """Response envelope for the search endpoint."""

    # One dictionary per hit, in the order the hits were produced.
    results: list[dict]
@app.get("/")
async def root():
    """Return a small JSON status document including the collection size."""
    # NOTE(review): Gradio is mounted on the same "/" path at the end of the
    # file — confirm which of the two is expected to answer GET /.
    item_count = collection.count()
    status_payload = dict(
        message="MedQA Search API",
        status="running",
        collection_count=item_count,
    )
    return status_payload
@app.post("/search_medqa", response_model=SearchResponse)
async def search_medqa(request: SearchRequest):
    """Embed the query, search the collection and return the hits as JSON.

    Args:
        request: Validated request body carrying ``query`` and
            ``num_results``.

    Returns:
        A ``SearchResponse`` whose ``results`` list holds one dict per hit
        with the keys ``example_number`` (1-based), ``question``,
        ``answer`` and ``distance``.

    Raises:
        HTTPException: With status 500 and the error text as detail when
            embedding, querying or formatting fails.
    """
    # NOTE(review): encode() and query() are called synchronously inside an
    # async handler; presumably they block the event loop while running —
    # confirm whether they should be offloaded to a worker thread.
    try:
        embedding = model.encode(request.query).tolist()
        results = collection.query(query_embeddings=[embedding], n_results=request.num_results)
        # Only one embedding is sent, so index 0 selects its result lists.
        hits = zip(
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0],
        )
        formatted_results = [
            {
                "example_number": rank,
                "question": document,
                "answer": metadata.get('answer', 'N/A'),
                "distance": distance,
            }
            for rank, (document, metadata, distance) in enumerate(hits, start=1)
        ]
        return SearchResponse(results=formatted_results)
    except Exception as e:
        # Chain the original exception so its traceback is kept as the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Gradio interface: a two-input form (query text and result count) wired to
# search_gradio, with the formatted hits shown in a single text box.
demo = gr.Interface(
    fn=search_gradio,
    title="MedQA Search - USMLE Question Database",
    description="Search for similar USMLE Step 1 questions using semantic similarity",
    inputs=[
        gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
        gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of Results"),
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    examples=[
        ["hyponatremia", 3],
        ["myocardial infarction", 2],
    ],
)

# Mount Gradio to FastAPI; the returned application replaces `app`.
app = gr.mount_gradio_app(app, demo, path="/")