Spaces:
Sleeping
Sleeping
| import os | |
| os.environ['ANONYMIZED_TELEMETRY'] = 'False' | |
| import zipfile | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
# Database bootstrap: unpack the bundled archive on first start, then open
# the persistent Chroma collection and load the query encoder.
DB_PATH = "./medqa_db"

# Extraction only happens when the database directory is absent and the
# zip archive is present next to this script.
if os.path.exists("./medqa_db.zip") and not os.path.exists(DB_PATH):
    with zipfile.ZipFile("./medqa_db.zip") as archive:
        archive.extractall(path=".")

client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")

# Encoder used to embed incoming queries before the vector lookup.
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
# Search function
def search(query, num_results=3):
    """Embed *query* and look up its nearest entries in the collection.

    Args:
        query: Text handed to ``model.encode``.
        num_results: Number of hits to request; coerced with ``int()``
            before it is passed on as ``n_results``.

    Returns:
        The value returned by ``collection.query``.  Callers in this file
        read its ``'documents'``, ``'metadatas'`` and ``'distances'`` keys.
    """
    query_vector = model.encode(query).tolist()
    top_k = int(num_results)
    return collection.query(query_embeddings=[query_vector], n_results=top_k)
# Gradio UI
def _split_question_and_choices(question_text):
    """Split a stored document into question-stem lines and answer choices.

    A line whose stripped form starts with ``A``-``E`` followed by ``.`` or
    ``)`` opens a new choice.  Once the first choice has been seen, every
    later non-blank line is folded into the most recent choice (choices may
    wrap over several lines) and blank lines are dropped.  All lines before
    the first choice belong to the question stem.

    Returns:
        A ``(question_lines, choices)`` tuple of two lists of strings.
    """
    import re  # the file's top-level imports do not provide ``re``

    choice_start = re.compile(r'^[A-E][\.\)]')
    question_lines = []
    choices = []
    for line in question_text.split('\n'):
        stripped = line.strip()
        if choice_start.match(stripped):
            choices.append(line)
        elif choices:
            # Already inside the choices section: a non-blank line here is
            # the wrapped continuation of the previous choice.
            if stripped:
                choices[-1] += " " + stripped
        else:
            question_lines.append(line)
    return question_lines, choices


def ui_search(query, num_results=3):
    """Run a search and render the hits as plain text for the Gradio UI.

    Args:
        query: Text typed by the user.  ``None``, empty and whitespace-only
            values are rejected with a prompt instead of being searched.
        num_results: Number of hits to request from ``search``.

    Returns:
        ``"Enter a query"`` for a blank query, ``"Error: <message>"`` if
        anything fails while searching or formatting, otherwise one text
        section per hit: a header, a debug dump of the first 500 characters
        of the raw document, the question stem, the answer choices (when any
        were detected), the stored answer and a similarity score.
    """
    # ``not query`` also covers None, which has no ``.strip()``.
    if not query or not query.strip():
        return "Enter a query"

    try:
        r = search(query, num_results)
        rule = "=" * 60
        parts = []
        for i, question_text in enumerate(r['documents'][0]):
            parts.append(f"\n{rule}\nExample {i+1}\n{rule}\n")

            # NOTE(review): this debug dump is still shown to end users;
            # drop it once the choice parsing has been verified.
            parts.append("DEBUG RAW TEXT:\n")
            parts.append(repr(question_text[:500]) + "\n")
            parts.append(rule + "\n\n")

            question_lines, choices = _split_question_and_choices(question_text)
            parts.append('\n'.join(question_lines).strip() + "\n\n")

            if choices:
                parts.append("Answer Choices:\n")
                parts.extend(choice.strip() + "\n" for choice in choices)
                parts.append("\n")

            parts.append(f"Correct Answer: {r['metadatas'][0][i].get('answer', 'N/A')}\n")
            # Shown as 1 - distance; presumably the collection uses a
            # distance where that is meaningful (e.g. cosine) -- TODO confirm.
            parts.append(f"Similarity: {1 - r['distances'][0][i]:.3f}\n")
        return "".join(parts)
    except Exception as e:
        # UI boundary: surface the failure as text rather than a traceback.
        return f"Error: {e}"
# Gradio interface: a query box plus a result-count slider feeding
# ``ui_search``, with the rendered hits shown in a text area.
demo = gr.Interface(
    title="MedQA Search",
    fn=ui_search,
    inputs=[
        gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia"),
        # Slider range 1-5, default 3, whole-number steps.
        gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Results"),
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    # Each example row is [query, number of results].
    examples=[
        ["hyponatremia", 3],
        ["myocardial infarction", 2],
    ],
)
# FastAPI application object; the Gradio UI is mounted onto it further down.
app = FastAPI()


class SearchRequest(BaseModel):
    """Request body consumed by ``api_search``."""

    # Text to embed and search for.
    query: str
    # Number of hits to return; 3 when the field is omitted.
    num_results: int = 3
# The handler was previously never registered on ``app``, so the JSON API
# was unreachable.  It is declared before the Gradio app is mounted at "/",
# so this route is added to the router ahead of that mount.
# NOTE(review): the "/search" path is chosen here -- confirm with clients.
@app.post("/search")
def api_search(req: SearchRequest):
    """Return the nearest documents for a request as a JSON-ready dict.

    Args:
        req: Parsed request body carrying ``query`` and ``num_results``.

    Returns:
        ``{"results": [...]}`` with one entry per hit, each holding a
        1-based ``example_number``, the stored ``question`` text, the
        ``answer`` from the hit's metadata (``'N/A'`` when absent) and the
        raw ``distance`` reported by the collection.
    """
    r = search(req.query, req.num_results)
    results = [
        {
            "example_number": i + 1,
            "question": document,
            "answer": r['metadatas'][0][i].get('answer', 'N/A'),
            "distance": r['distances'][0][i],
        }
        for i, document in enumerate(r['documents'][0])
    ]
    return {"results": results}
# Attach the Gradio interface to the FastAPI app at the root path and keep
# the returned application under the same name.
app = gr.mount_gradio_app(app, demo, path="/")

# Launch the server only when this file is executed as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app=app, port=7860, host="0.0.0.0")