# Page residue from the Hugging Face file view (not code):
# imurra's picture | to see raw question | 249c206 verified | raw | history blame | 3.86 kB
import os
import re
import zipfile

# Set before chromadb is imported; presumably disables Chroma's anonymized telemetry.
os.environ['ANONYMIZED_TELEMETRY'] = 'False'

import chromadb
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# Extract and load database
DB_PATH = "./medqa_db"


def _ensure_db_extracted(zip_path="./medqa_db.zip"):
    """Unpack the bundled database archive into the working directory if needed.

    Does nothing when DB_PATH already exists or when the archive is missing.
    """
    if os.path.exists(DB_PATH):
        return
    if not os.path.exists(zip_path):
        return
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(".")


_ensure_db_extracted()
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
# Query-side MedCPT encoder; the stored vectors are presumably from the matching
# MedCPT encoder — confirm against the indexing script.
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
# Search function
def search(query, num_results=3):
    """Embed *query* with the module-level encoder and look up its nearest neighbours.

    Returns the raw ``collection.query`` result for a single query embedding,
    requesting ``int(num_results)`` hits.
    """
    query_vector = model.encode(query).tolist()
    top_k = int(num_results)
    return collection.query(n_results=top_k, query_embeddings=[query_vector])
# Gradio UI
# An answer-choice line starts with a letter A-E followed by "." or ")".
_CHOICE_RE = re.compile(r'^[A-E][.)]')


def _split_question_and_choices(question_text):
    """Split a stored question into its stem and its answer-choice lines.

    Choice mode only starts at a line beginning with "A." or "A)". This keeps a
    stem line such as "E. coli was isolated..." from being mistaken for a choice.
    Once in choice mode:
      - a line matching A-E plus "." or ")" starts a new choice;
      - any other non-blank line continues the previous choice;
      - blank lines are dropped.

    Returns:
        (stem, choices): the stripped stem text and a list of stripped choices.
    """
    stem_lines = []
    choices = []
    for raw_line in question_text.splitlines():
        line = raw_line.strip()
        if choices:
            if _CHOICE_RE.match(line):
                choices.append(line)
            elif line:
                choices[-1] += " " + line
        elif line.startswith(('A.', 'A)')):
            choices.append(line)
        else:
            stem_lines.append(raw_line)
    return '\n'.join(stem_lines).strip(), choices


def _format_example(number, question_text, metadata, distance):
    """Render one search hit as a plain-text section for the Gradio output box."""
    bar = '=' * 60
    stem, choices = _split_question_and_choices(question_text)
    parts = [f"\n{bar}\nExample {number}\n{bar}\n", stem + "\n\n"]
    if choices:
        parts.append("Answer Choices:\n")
        parts.extend(choice + "\n" for choice in choices)
        parts.append("\n")
    # Entries stored without metadata may come back as None.
    answer = (metadata or {}).get('answer', 'N/A')
    parts.append(f"Correct Answer: {answer}\n")
    # NOTE(review): 1 - distance is only a true similarity for cosine distance;
    # confirm the collection's distance metric.
    parts.append(f"Similarity: {1 - distance:.3f}\n")
    return ''.join(parts)


def ui_search(query, num_results=3):
    """Search MedQA and format the hits as text for the Gradio interface.

    Args:
        query: Free-text query from the textbox. None or blank input returns
            a prompt instead of searching.
        num_results: Number of neighbours to request (the slider value).

    Returns:
        One of:
          - the formatted hits, one section per result;
          - "Enter a query" for empty input;
          - "Error: <message>" if searching or formatting fails.
    """
    if not query or not query.strip():
        return "Enter a query"
    try:
        r = search(query, num_results)
        hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
        return ''.join(
            _format_example(number, document, metadata, distance)
            for number, (document, metadata, distance) in enumerate(hits, start=1)
        )
    except Exception as e:
        # UI boundary: show the failure in the output box instead of raising into Gradio.
        return f"Error: {e}"
# Input/output components for the search form, created in the same order as before.
_query_box = gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia")
_count_slider = gr.Slider(1, 5, value=3, step=1, label="Results")
_results_box = gr.Textbox(label="Similar USMLE Questions", lines=20)

# Gradio front end: query text + result count in, formatted matches out.
demo = gr.Interface(
    fn=ui_search,
    inputs=[_query_box, _count_slider],
    outputs=_results_box,
    title="MedQA Search",
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]],
)
# FastAPI
app = FastAPI()


class SearchRequest(BaseModel):
    """JSON body accepted by POST /search_medqa."""

    query: str  # free-text medical query to embed and search
    # Must be >= 1. A non-positive count is rejected during request validation
    # (a 422 response) instead of being passed on to collection.query.
    num_results: int = Field(3, ge=1)
@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """Return the nearest MedQA questions for ``req.query`` as JSON.

    Each result contains:
      - ``example_number``: 1-based position of the hit;
      - ``question``: the stored question text;
      - ``answer``: the answer from its metadata ('N/A' when absent);
      - ``distance``: the raw distance reported by the collection.
    """
    r = search(req.query, req.num_results)
    hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
    return {"results": [
        {
            "example_number": number,
            "question": document,
            # Entries stored without metadata may come back as None.
            "answer": (metadata or {}).get('answer', 'N/A'),
            "distance": distance,
        }
        for number, (document, metadata, distance) in enumerate(hits, start=1)
    ]}
# Serve the Gradio UI at the root path. The /search_medqa route is registered
# on the FastAPI app before this mount.
app = gr.mount_gradio_app(app, demo, path="/")


def _serve():
    """Launch the combined FastAPI + Gradio app with uvicorn on 0.0.0.0:7860."""
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)


# Launch the server
if __name__ == "__main__":
    _serve()