Spaces:
Sleeping
Sleeping
| # app.py - Hugging Face Spaces version - FIXED | |
| import os | |
| import zipfile | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
# Locations of the unpacked database directory and its zipped copy
DB_PATH = "./medqa_db"
ZIP_PATH = "./medqa_db.zip"

# Unpack the archive into the working directory only when the zip is present
# and the database directory has not been created yet.
if os.path.exists(ZIP_PATH) and not os.path.exists(DB_PATH):
    print("Extracting database from zip file...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as archive:
        archive.extractall(".")
    print("Database extracted successfully!")
# Initialize: open the persistent Chroma store at DB_PATH and load the
# sentence-transformers model. Both run once at import time and the resulting
# module-level `collection` and `model` are used by the handlers below.
print(f"Loading database from: {DB_PATH}")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print("Collection loaded with {} items".format(collection.count()))

print("Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("Initialization complete!")
# Gradio interface function
def search_interface(query: str, num_results: int = 3):
    """Run a semantic search and render the hits as plain text for the UI.

    Args:
        query: Free-text medical topic or clinical scenario. ``None``, empty
            and whitespace-only values are answered with a prompt message.
        num_results: Number of matches to request; coerced with ``int()``
            before it is passed to the collection query.

    Returns:
        A multi-line string with one section per hit (question text, answer
        and ``1 - distance`` as "Similarity"), or a short message when the
        query is blank, nothing matched, or the search raised. Errors are
        returned as text rather than raised so the UI always shows something.
    """
    # Check for None before calling .strip(); a missing value must not crash.
    if not query or not query.strip():
        return "Please enter a search query."

    try:
        embedding = model.encode(query).tolist()
        results = collection.query(
            query_embeddings=[embedding],
            n_results=int(num_results),
        )

        # Results are batched per query embedding; only one was sent.
        documents = results['documents'][0]
        if not documents:
            return "No results found."

        metadatas = results['metadatas'][0]
        distances = results['distances'][0]
        divider = '=' * 60

        sections = []
        for rank, (document, metadata, distance) in enumerate(
            zip(documents, metadatas, distances), start=1
        ):
            # NOTE(review): metadata may be None for entries stored without
            # any; fall back to an empty mapping so the lookup cannot fail.
            answer = (metadata or {}).get('answer', 'N/A')
            sections.append(
                f"\n{divider}\n"
                f"Example {rank}\n"
                f"{divider}\n"
                f"{document}\n"
                f"\nAnswer: {answer}\n"
                f"Similarity: {1 - distance:.3f}\n"
            )
        return "".join(sections)
    except Exception as e:
        # UI boundary: surface the failure as text instead of propagating.
        return f"Error: {str(e)}"
# Create Gradio interface: a two-column page with the query controls on the
# left and the rendered results on the right.
with gr.Blocks(title="MedQA Search") as demo:
    gr.Markdown("# MedQA Search - USMLE Question Database")
    gr.Markdown("Search for similar USMLE Step 1 questions using semantic similarity")

    with gr.Row():
        # Input side: query text, result count and the trigger button
        with gr.Column():
            query_input = gr.Textbox(
                label="Medical Topic or Clinical Scenario",
                placeholder="e.g., hyponatremia",
                lines=2,
            )
            num_results_slider = gr.Slider(
                minimum=1, maximum=5, value=3, step=1,
                label="Number of Examples",
            )
            search_btn = gr.Button("Search", variant="primary")

        # Output side: plain-text rendering produced by search_interface
        with gr.Column():
            output_text = gr.Textbox(
                label="Similar USMLE Questions", lines=25, max_lines=50
            )

    # Wire the button to the search function
    search_btn.click(
        fn=search_interface,
        inputs=[query_input, num_results_slider],
        outputs=output_text,
    )

    # Clickable sample queries that pre-fill both inputs
    gr.Examples(
        examples=[
            ["hyponatremia", 3],
            ["myocardial infarction", 2],
            ["diabetic ketoacidosis", 3],
        ],
        inputs=[query_input, num_results_slider],
    )
# FastAPI for API endpoints
app = FastAPI(title="MedQA Search API")

# Allow browser clients on any origin to call the API. Credentials are kept
# disabled: a wildcard origin combined with allow_credentials=True is not a
# valid CORS configuration (FastAPI's CORS documentation requires explicit
# origins, methods and headers when credentials are enabled), and no handler
# visible in this file reads cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class SearchRequest(BaseModel):
    """Request body consumed by ``search_medqa``."""

    # Free-text query that is embedded and matched against the collection
    query: str
    # Number of matches to return; 3 when the field is omitted
    num_results: int = 3
class SearchResponse(BaseModel):
    """Response envelope returned by ``search_medqa``."""

    # One dict per hit, in the order the collection query returned them
    results: list[dict]
# NOTE(review): no route decorator is visible on this handler in the source as
# extracted, so as shown it is never registered on `app` - confirm the
# intended path against the deployed file.
async def root():
    """Return a short status payload that includes the collection size."""
    item_total = collection.count()
    return {
        "message": "MedQA Search API - Hugging Face Version",
        "status": "running",
        "collection_count": item_total,
    }
# NOTE(review): no route decorator is visible on this handler in the source as
# extracted, so as shown it is never registered on `app` - confirm the
# intended path and HTTP method against the deployed file.
async def search_medqa(request: SearchRequest):
    """Search MedQA database for similar USMLE questions.

    Args:
        request: Query text plus the number of matches wanted.

    Returns:
        A ``SearchResponse`` whose ``results`` holds one dict per hit with
        the keys ``example_number`` (1-based), ``question``, ``answer`` and
        ``distance`` (``None`` when the query returned no distances).

    Raises:
        HTTPException: 400 when the query is blank or ``num_results`` is
            below 1; 500 when embedding or the collection query fails.
    """
    # Reject unusable input up front so it is reported as a client error
    # instead of surfacing as a 500 from the vector store. Blank queries are
    # refused here for the same reason search_interface refuses them.
    if not request.query.strip():
        raise HTTPException(status_code=400, detail="query must not be empty")
    if request.num_results < 1:
        raise HTTPException(status_code=400, detail="num_results must be at least 1")

    try:
        embedding = model.encode(request.query).tolist()
        results = collection.query(
            query_embeddings=[embedding],
            n_results=request.num_results,
        )

        # Results are batched per query embedding; only one was sent.
        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        # The 'distances' entry may be absent or None; a bare membership test
        # would not protect the subsequent indexing in the None case.
        distance_batches = results.get('distances')
        if distance_batches:
            distances = distance_batches[0]
        else:
            distances = [None] * len(documents)

        formatted_results = [
            {
                "example_number": rank,
                "question": document,
                # NOTE(review): metadata may be None for entries stored
                # without any; fall back to an empty mapping.
                "answer": (metadata or {}).get('answer', 'N/A'),
                "distance": distance,
            }
            for rank, (document, metadata, distance) in enumerate(
                zip(documents, metadatas, distances), start=1
            )
        ]
        return SearchResponse(results=formatted_results)
    except Exception as e:
        # API boundary: report any backend failure as a 500 with its message.
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this handler in the source as
# extracted, so as shown it is never registered on `app` - confirm the
# intended path against the deployed file.
async def health():
    """Report a healthy status together with the current collection size."""
    return dict(status="healthy", items=collection.count())
# Mount Gradio on FastAPI: serve the `demo` UI from the site root and rebind
# `app` to the object returned by the mount call.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")