Spaces:
Sleeping
Sleeping
Need to add FastAPI endpoint by using only gr.mount_gradio_app() method (no demo.launch()), then configure ChatGPT Custom GPT Action with the API endpoint.
Browse files
app.py
CHANGED
|
@@ -6,80 +6,62 @@ import gradio as gr
|
|
| 6 |
from fastapi import FastAPI
|
| 7 |
from pydantic import BaseModel
|
| 8 |
|
| 9 |
-
# Extract database
|
| 10 |
DB_PATH = "./medqa_db"
|
| 11 |
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
zip_ref.extractall(".")
|
| 15 |
-
print("Extracted!")
|
| 16 |
|
| 17 |
-
# Load database and model
|
| 18 |
-
print(f"Loading database from: {DB_PATH}")
|
| 19 |
client = chromadb.PersistentClient(path=DB_PATH)
|
| 20 |
collection = client.get_collection("medqa")
|
| 21 |
-
print(f"Loaded {collection.count()} items")
|
| 22 |
-
print("Loading model...")
|
| 23 |
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
|
| 24 |
-
print("Ready!")
|
| 25 |
|
| 26 |
# Search function
|
| 27 |
-
def
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
return results
|
| 31 |
|
| 32 |
-
# Gradio
|
| 33 |
-
def
|
| 34 |
if not query.strip():
|
| 35 |
-
return "
|
| 36 |
-
|
| 37 |
try:
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
for i in range(len(
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
return output
|
| 46 |
except Exception as e:
|
| 47 |
-
return f"Error: {
|
| 48 |
|
| 49 |
demo = gr.Interface(
|
| 50 |
-
fn=
|
| 51 |
inputs=[
|
| 52 |
-
gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia"
|
| 53 |
-
gr.Slider(1, 5, value=3, step=1, label="
|
| 54 |
],
|
| 55 |
outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
|
| 56 |
title="MedQA Search",
|
| 57 |
-
description="Search for similar USMLE Step 1 questions",
|
| 58 |
examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
|
| 59 |
)
|
| 60 |
|
| 61 |
-
# FastAPI
|
| 62 |
app = FastAPI()
|
| 63 |
|
| 64 |
class SearchRequest(BaseModel):
|
| 65 |
query: str
|
| 66 |
num_results: int = 3
|
| 67 |
|
| 68 |
-
class SearchResponse(BaseModel):
|
| 69 |
-
results: list[dict]
|
| 70 |
-
|
| 71 |
@app.post("/search_medqa")
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
"distance": results['distances'][0][i]
|
| 81 |
-
})
|
| 82 |
-
return SearchResponse(results=formatted)
|
| 83 |
|
| 84 |
-
# Mount Gradio on FastAPI
|
| 85 |
app = gr.mount_gradio_app(app, demo, path="/")
|
|
|
|
| 6 |
from fastapi import FastAPI
|
| 7 |
from pydantic import BaseModel
|
| 8 |
|
| 9 |
+
# Extract and load database
|
| 10 |
DB_PATH = "./medqa_db"
|
| 11 |
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
|
| 12 |
+
with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
|
| 13 |
+
z.extractall(".")
|
|
|
|
|
|
|
| 14 |
|
|
|
|
|
|
|
| 15 |
client = chromadb.PersistentClient(path=DB_PATH)
|
| 16 |
collection = client.get_collection("medqa")
|
|
|
|
|
|
|
| 17 |
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
|
|
|
|
| 18 |
|
| 19 |
# Search function
|
| 20 |
+
def search(query, num_results=3):
|
| 21 |
+
emb = model.encode(query).tolist()
|
| 22 |
+
return collection.query(query_embeddings=[emb], n_results=int(num_results))
|
|
|
|
| 23 |
|
| 24 |
+
# Gradio UI
|
| 25 |
+
def ui_search(query, num_results=3):
|
| 26 |
if not query.strip():
|
| 27 |
+
return "Enter a query"
|
|
|
|
| 28 |
try:
|
| 29 |
+
r = search(query, num_results)
|
| 30 |
+
out = ""
|
| 31 |
+
for i in range(len(r['documents'][0])):
|
| 32 |
+
out += f"\n{'='*60}\nExample {i+1}\n{'='*60}\n"
|
| 33 |
+
out += r['documents'][0][i] + f"\n\nAnswer: {r['metadatas'][0][i].get('answer', 'N/A')}\n"
|
| 34 |
+
out += f"Similarity: {1 - r['distances'][0][i]:.3f}\n"
|
| 35 |
+
return out
|
|
|
|
| 36 |
except Exception as e:
|
| 37 |
+
return f"Error: {e}"
|
| 38 |
|
| 39 |
demo = gr.Interface(
|
| 40 |
+
fn=ui_search,
|
| 41 |
inputs=[
|
| 42 |
+
gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia"),
|
| 43 |
+
gr.Slider(1, 5, value=3, step=1, label="Results")
|
| 44 |
],
|
| 45 |
outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
|
| 46 |
title="MedQA Search",
|
|
|
|
| 47 |
examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
|
| 48 |
)
|
| 49 |
|
| 50 |
+
# FastAPI
|
| 51 |
app = FastAPI()
|
| 52 |
|
| 53 |
class SearchRequest(BaseModel):
|
| 54 |
query: str
|
| 55 |
num_results: int = 3
|
| 56 |
|
|
|
|
|
|
|
|
|
|
| 57 |
@app.post("/search_medqa")
|
| 58 |
+
def api_search(req: SearchRequest):
|
| 59 |
+
r = search(req.query, req.num_results)
|
| 60 |
+
return {"results": [{
|
| 61 |
+
"example_number": i+1,
|
| 62 |
+
"question": r['documents'][0][i],
|
| 63 |
+
"answer": r['metadatas'][0][i].get('answer', 'N/A'),
|
| 64 |
+
"distance": r['distances'][0][i]
|
| 65 |
+
} for i in range(len(r['documents'][0]))]}
|
|
|
|
|
|
|
|
|
|
| 66 |
|
|
|
|
| 67 |
app = gr.mount_gradio_app(app, demo, path="/")
|