imurra committed on
Commit
534ed03
·
verified ·
1 Parent(s): 15c714e

Key change: Added demo.launch(server_name="0.0.0.0", server_port=7860) at the end - this keeps the server running!

Browse files
Files changed (1) hide show
  1. app.py +17 -68
app.py CHANGED
@@ -1,37 +1,30 @@
1
- # app.py - Hugging Face Spaces version - FIXED
2
  import os
3
  import zipfile
4
- from fastapi import FastAPI, HTTPException
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from pydantic import BaseModel
7
  import chromadb
8
  from sentence_transformers import SentenceTransformer
9
  import gradio as gr
10
 
11
- # Database setup
12
  DB_PATH = "./medqa_db"
13
- ZIP_PATH = "./medqa_db.zip"
14
-
15
- # Extract database if needed
16
- if not os.path.exists(DB_PATH) and os.path.exists(ZIP_PATH):
17
- print("Extracting database from zip file...")
18
- with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
19
  zip_ref.extractall(".")
20
- print("Database extracted successfully!")
21
 
22
  # Load database and model
23
  print(f"Loading database from: {DB_PATH}")
24
  client = chromadb.PersistentClient(path=DB_PATH)
25
  collection = client.get_collection("medqa")
26
- print(f"Collection loaded with {collection.count()} items")
27
- print("Loading MedCPT model...")
28
  model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
29
- print("Initialization complete!")
30
 
31
- # Search function for Gradio
32
- def search_gradio(query, num_results=3):
33
  if not query.strip():
34
- return "Please enter a query."
 
35
  try:
36
  embedding = model.encode(query).tolist()
37
  results = collection.query(query_embeddings=[embedding], n_results=int(num_results))
@@ -46,62 +39,18 @@ def search_gradio(query, num_results=3):
46
  except Exception as e:
47
  return f"Error: {str(e)}"
48
 
49
- # FastAPI app
50
- app = FastAPI(title="MedQA Search API")
51
-
52
- app.add_middleware(
53
- CORSMiddleware,
54
- allow_origins=["*"],
55
- allow_credentials=True,
56
- allow_methods=["*"],
57
- allow_headers=["*"],
58
- )
59
-
60
- class SearchRequest(BaseModel):
61
- query: str
62
- num_results: int = 3
63
-
64
- class SearchResponse(BaseModel):
65
- results: list[dict]
66
-
67
- @app.get("/")
68
- async def root():
69
- return {
70
- "message": "MedQA Search API",
71
- "status": "running",
72
- "collection_count": collection.count()
73
- }
74
-
75
- @app.post("/search_medqa", response_model=SearchResponse)
76
- async def search_medqa(request: SearchRequest):
77
- try:
78
- embedding = model.encode(request.query).tolist()
79
- results = collection.query(query_embeddings=[embedding], n_results=request.num_results)
80
-
81
- formatted_results = []
82
- for i in range(len(results['documents'][0])):
83
- formatted_results.append({
84
- "example_number": i + 1,
85
- "question": results['documents'][0][i],
86
- "answer": results['metadatas'][0][i].get('answer', 'N/A'),
87
- "distance": results['distances'][0][i]
88
- })
89
- return SearchResponse(results=formatted_results)
90
- except Exception as e:
91
- raise HTTPException(status_code=500, detail=str(e))
92
-
93
- # Gradio interface
94
  demo = gr.Interface(
95
- fn=search_gradio,
96
  inputs=[
97
  gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
98
  gr.Slider(1, 5, value=3, step=1, label="Number of Results")
99
  ],
100
  outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
101
- title="MedQA Search - USMLE Question Database",
102
- description="Search for similar USMLE Step 1 questions using semantic similarity",
103
  examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
104
  )
105
 
106
- # Mount Gradio to FastAPI
107
- app = gr.mount_gradio_app(app, demo, path="/")
 
 
1
  import os
2
  import zipfile
 
 
 
3
  import chromadb
4
  from sentence_transformers import SentenceTransformer
5
  import gradio as gr
6
 
7
+ # Extract database
8
  DB_PATH = "./medqa_db"
9
+ if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
10
+ print("Extracting database...")
11
+ with zipfile.ZipFile("./medqa_db.zip", 'r') as zip_ref:
 
 
 
12
  zip_ref.extractall(".")
13
+ print("Extracted!")
14
 
15
  # Load database and model
16
  print(f"Loading database from: {DB_PATH}")
17
  client = chromadb.PersistentClient(path=DB_PATH)
18
  collection = client.get_collection("medqa")
19
+ print(f"Loaded {collection.count()} items")
20
+ print("Loading model...")
21
  model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
22
+ print("Ready!")
23
 
24
+ def search_medqa(query, num_results=3):
 
25
  if not query.strip():
26
+ return "Please enter a search query."
27
+
28
  try:
29
  embedding = model.encode(query).tolist()
30
  results = collection.query(query_embeddings=[embedding], n_results=int(num_results))
 
39
  except Exception as e:
40
  return f"Error: {str(e)}"
41
 
42
+ # Create Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  demo = gr.Interface(
44
+ fn=search_medqa,
45
  inputs=[
46
  gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
47
  gr.Slider(1, 5, value=3, step=1, label="Number of Results")
48
  ],
49
  outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
50
+ title="MedQA Search",
51
+ description="Search for similar USMLE Step 1 questions",
52
  examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
53
  )
54
 
55
+ # Launch - this is the key line!
56
+ demo.launch(server_name="0.0.0.0", server_port=7860)