Spaces:
Sleeping
Sleeping
Update app.py
Browse filesBut the Gradio interface doesn't expose a /search_medqa endpoint. We need to add FastAPI back for the API.
✅ Web UI at: https://imurra-medqa-api-online.hf.space/
✅ API endpoint at: https://imurra-medqa-api-online.hf.space/search_medqa
app.py
CHANGED
|
@@ -3,6 +3,8 @@ import zipfile
|
|
| 3 |
import chromadb
|
| 4 |
from sentence_transformers import SentenceTransformer
|
| 5 |
import gradio as gr
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# Extract database
|
| 8 |
DB_PATH = "./medqa_db"
|
|
@@ -21,14 +23,19 @@ print("Loading model...")
|
|
| 21 |
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
|
| 22 |
print("Ready!")
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
if not query.strip():
|
| 26 |
return "Please enter a search query."
|
| 27 |
|
| 28 |
try:
|
| 29 |
-
|
| 30 |
-
results = collection.query(query_embeddings=[embedding], n_results=int(num_results))
|
| 31 |
-
|
| 32 |
output = ""
|
| 33 |
for i in range(len(results['documents'][0])):
|
| 34 |
output += f"\n{'='*60}\nExample {i+1}\n{'='*60}\n"
|
|
@@ -39,9 +46,8 @@ def search_medqa(query, num_results=3):
|
|
| 39 |
except Exception as e:
|
| 40 |
return f"Error: {str(e)}"
|
| 41 |
|
| 42 |
-
# Create Gradio interface
|
| 43 |
demo = gr.Interface(
|
| 44 |
-
fn=
|
| 45 |
inputs=[
|
| 46 |
gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
|
| 47 |
gr.Slider(1, 5, value=3, step=1, label="Number of Results")
|
|
@@ -52,5 +58,28 @@ demo = gr.Interface(
|
|
| 52 |
examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
|
| 53 |
)
|
| 54 |
|
| 55 |
-
#
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import chromadb
|
| 4 |
from sentence_transformers import SentenceTransformer
|
| 5 |
import gradio as gr
|
| 6 |
+
from fastapi import FastAPI
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
|
| 9 |
# Extract database
|
| 10 |
DB_PATH = "./medqa_db"
|
|
|
|
| 23 |
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
|
| 24 |
print("Ready!")
|
| 25 |
|
| 26 |
+
# Search function
|
| 27 |
+
def search_function(query, num_results=3):
|
| 28 |
+
embedding = model.encode(query).tolist()
|
| 29 |
+
results = collection.query(query_embeddings=[embedding], n_results=int(num_results))
|
| 30 |
+
return results
|
| 31 |
+
|
| 32 |
+
# Gradio interface
|
| 33 |
+
def search_gradio(query, num_results=3):
|
| 34 |
if not query.strip():
|
| 35 |
return "Please enter a search query."
|
| 36 |
|
| 37 |
try:
|
| 38 |
+
results = search_function(query, num_results)
|
|
|
|
|
|
|
| 39 |
output = ""
|
| 40 |
for i in range(len(results['documents'][0])):
|
| 41 |
output += f"\n{'='*60}\nExample {i+1}\n{'='*60}\n"
|
|
|
|
| 46 |
except Exception as e:
|
| 47 |
return f"Error: {str(e)}"
|
| 48 |
|
|
|
|
| 49 |
demo = gr.Interface(
|
| 50 |
+
fn=search_gradio,
|
| 51 |
inputs=[
|
| 52 |
gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
|
| 53 |
gr.Slider(1, 5, value=3, step=1, label="Number of Results")
|
|
|
|
| 58 |
examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
|
| 59 |
)
|
| 60 |
|
| 61 |
+
# FastAPI for ChatGPT
|
| 62 |
+
app = FastAPI()
|
| 63 |
+
|
| 64 |
+
class SearchRequest(BaseModel):
|
| 65 |
+
query: str
|
| 66 |
+
num_results: int = 3
|
| 67 |
+
|
| 68 |
+
class SearchResponse(BaseModel):
|
| 69 |
+
results: list[dict]
|
| 70 |
+
|
| 71 |
+
@app.post("/search_medqa")
|
| 72 |
+
async def api_search(request: SearchRequest):
|
| 73 |
+
results = search_function(request.query, request.num_results)
|
| 74 |
+
formatted = []
|
| 75 |
+
for i in range(len(results['documents'][0])):
|
| 76 |
+
formatted.append({
|
| 77 |
+
"example_number": i + 1,
|
| 78 |
+
"question": results['documents'][0][i],
|
| 79 |
+
"answer": results['metadatas'][0][i].get('answer', 'N/A'),
|
| 80 |
+
"distance": results['distances'][0][i]
|
| 81 |
+
})
|
| 82 |
+
return SearchResponse(results=formatted)
|
| 83 |
+
|
| 84 |
+
# Mount Gradio on FastAPI
|
| 85 |
+
app = gr.mount_gradio_app(app, demo, path="/")
|