"""MedQA similarity search served through a Gradio UI and a FastAPI endpoint.

On import this module extracts ``./medqa_db.zip`` when ``./medqa_db`` is
missing, opens the persistent Chroma collection ``medqa`` and loads the
``ncbi/MedCPT-Query-Encoder`` sentence-transformer model.
"""
import os
import re
import zipfile
from typing import Any, Dict, List, Optional, Tuple

# Set before chromadb is imported, as in the original statement order.
# NOTE(review): presumably Chroma reads this when it builds its settings;
# confirm before moving it.
os.environ['ANONYMIZED_TELEMETRY'] = 'False'

import chromadb  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
import gradio as gr  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from pydantic import BaseModel, Field  # noqa: E402

# Extract and load database
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
        z.extractall(".")

client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')

# A line opens an answer choice when it starts with A-E followed by "." or ")".
_CHOICE_RE = re.compile(r'^[A-E][.)]')

# Width of the "=====" rule printed around each example header.
_RULE = "=" * 60

# The raw-text dump is a debugging aid only; enable it with MEDQA_DEBUG=1.
_DEBUG = os.environ.get("MEDQA_DEBUG", "").strip().lower() in {"1", "true", "yes"}


# Search function
def search(query, num_results=3):
    """Embed *query* and return the raw Chroma query result.

    Args:
        query: Free-text query passed to the sentence-transformer encoder.
        num_results: Number of neighbours to request; coerced with ``int()``
            because UI widgets may deliver it as a float.

    Returns:
        Whatever ``collection.query`` returns. The callers in this module
        read the ``documents``, ``metadatas`` and ``distances`` keys, each
        indexed ``[0][i]`` for the single query sent.
    """
    emb = model.encode(query).tolist()
    return collection.query(query_embeddings=[emb], n_results=int(num_results))


def _split_question_and_choices(text: str) -> Tuple[str, List[str]]:
    """Split a stored document into its question stem and answer choices.

    Lines before the first choice marker form the question. From the first
    marker on, every marker line starts a new choice and every other
    non-blank line is appended (space-separated) to the current choice.
    Blank lines inside the choices section are dropped.

    Returns:
        ``(question, choices)`` where *question* is stripped and *choices*
        is a list of stripped, single-line choice strings (possibly empty).
    """
    question_lines: List[str] = []
    choices: List[str] = []
    for line in text.split('\n'):
        stripped = line.strip()
        if _CHOICE_RE.match(stripped):
            choices.append(stripped)
        elif choices:
            # Continuation of a choice that wraps over several lines.
            if stripped:
                choices[-1] += " " + stripped
        else:
            question_lines.append(line)
    return '\n'.join(question_lines).strip(), choices


def _format_result(index: int, question_text: Optional[str],
                   metadata: Optional[Dict[str, Any]], distance: float) -> str:
    """Render one search hit as the text block shown in the Gradio output."""
    # Guard against entries stored without a document or without metadata.
    question_text = question_text or ""
    parts = [f"\n{_RULE}\nExample {index}\n{_RULE}\n"]

    if _DEBUG:
        parts.append("DEBUG RAW TEXT:\n")
        parts.append(repr(question_text[:500]) + "\n")
        parts.append(_RULE + "\n\n")

    question, choices = _split_question_and_choices(question_text)
    parts.append(question + "\n\n")
    if choices:
        parts.append("Answer Choices:\n")
        parts.extend(choice + "\n" for choice in choices)
        parts.append("\n")

    answer = (metadata or {}).get('answer', 'N/A')
    parts.append(f"Correct Answer: {answer}\n")
    # NOTE(review): "1 - distance" is only a similarity for metrics where
    # distance is 1 - similarity (e.g. cosine) — confirm the collection's
    # configured distance function.
    parts.append(f"Similarity: {1 - distance:.3f}\n")
    return "".join(parts)


# Gradio UI
def ui_search(query, num_results=3):
    """Run a search and format the hits as plain text for the Gradio textbox.

    Args:
        query: Text typed by the user; empty, whitespace-only or ``None``
            yields a prompt message instead of a search.
        num_results: Number of hits to show.

    Returns:
        The formatted results, ``"Enter a query"`` for a blank query,
        ``"No results found"`` when the search returns nothing, or
        ``"Error: ..."`` if anything raises while searching or formatting
        (this function is the UI boundary, so errors are shown, not raised).
    """
    if not query or not query.strip():
        return "Enter a query"
    try:
        r = search(query, num_results)
        hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
        out = "".join(
            _format_result(number, doc, meta, dist)
            for number, (doc, meta, dist) in enumerate(hits, start=1)
        )
        return out or "No results found"
    except Exception as e:
        return f"Error: {e}"


demo = gr.Interface(
    fn=ui_search,
    inputs=[
        gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia"),
        gr.Slider(1, 5, value=3, step=1, label="Results"),
    ],
    outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
    title="MedQA Search",
    examples=[["hyponatremia", 3], ["myocardial infarction", 2]],
)

# FastAPI
app = FastAPI()


class SearchRequest(BaseModel):
    """Request body for ``POST /search_medqa``."""

    query: str
    # At least one result must be requested; smaller values are rejected by
    # request validation instead of failing later inside the vector store.
    num_results: int = Field(3, ge=1)


@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """Return the nearest stored questions for ``req.query`` as JSON.

    Each entry carries its 1-based ``example_number``, the stored
    ``question`` text, the ``answer`` from the hit's metadata (``'N/A'``
    when absent) and the raw ``distance`` reported by the collection.
    """
    r = search(req.query, req.num_results)
    hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
    return {"results": [
        {
            "example_number": number,
            "question": doc,
            "answer": (meta or {}).get('answer', 'N/A'),
            "distance": dist,
        }
        for number, (doc, meta, dist) in enumerate(hits, start=1)
    ]}


# Serve the Gradio UI at the root of the same FastAPI application.
app = gr.mount_gradio_app(app, demo, path="/")

# Launch the server
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)