imurra commited on
Commit
98cb2a6
·
verified ·
1 Parent(s): f944c35

Need to add FastAPI endpoint by using only gr.mount_gradio_app() method (no demo.launch()), then configure ChatGPT Custom GPT Action with the API endpoint.

Browse files
Files changed (1) hide show
  1. app.py +29 -47
app.py CHANGED
@@ -6,80 +6,62 @@ import gradio as gr
6
  from fastapi import FastAPI
7
  from pydantic import BaseModel
8
 
9
- # Extract database
10
  DB_PATH = "./medqa_db"
11
  if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
12
- print("Extracting database...")
13
- with zipfile.ZipFile("./medqa_db.zip", 'r') as zip_ref:
14
- zip_ref.extractall(".")
15
- print("Extracted!")
16
 
17
- # Load database and model
18
- print(f"Loading database from: {DB_PATH}")
19
  client = chromadb.PersistentClient(path=DB_PATH)
20
  collection = client.get_collection("medqa")
21
- print(f"Loaded {collection.count()} items")
22
- print("Loading model...")
23
  model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
24
- print("Ready!")
25
 
26
  # Search function
27
- def search_function(query, num_results=3):
28
- embedding = model.encode(query).tolist()
29
- results = collection.query(query_embeddings=[embedding], n_results=int(num_results))
30
- return results
31
 
32
- # Gradio interface
33
- def search_gradio(query, num_results=3):
34
  if not query.strip():
35
- return "Please enter a search query."
36
-
37
  try:
38
- results = search_function(query, num_results)
39
- output = ""
40
- for i in range(len(results['documents'][0])):
41
- output += f"\n{'='*60}\nExample {i+1}\n{'='*60}\n"
42
- output += results['documents'][0][i] + "\n"
43
- output += f"\nAnswer: {results['metadatas'][0][i].get('answer', 'N/A')}\n"
44
- output += f"Similarity: {1 - results['distances'][0][i]:.3f}\n"
45
- return output
46
  except Exception as e:
47
- return f"Error: {str(e)}"
48
 
49
  demo = gr.Interface(
50
- fn=search_gradio,
51
  inputs=[
52
- gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia", lines=2),
53
- gr.Slider(1, 5, value=3, step=1, label="Number of Results")
54
  ],
55
  outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
56
  title="MedQA Search",
57
- description="Search for similar USMLE Step 1 questions",
58
  examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
59
  )
60
 
61
- # FastAPI for ChatGPT
62
  app = FastAPI()
63
 
64
  class SearchRequest(BaseModel):
65
  query: str
66
  num_results: int = 3
67
 
68
- class SearchResponse(BaseModel):
69
- results: list[dict]
70
-
71
  @app.post("/search_medqa")
72
- async def api_search(request: SearchRequest):
73
- results = search_function(request.query, request.num_results)
74
- formatted = []
75
- for i in range(len(results['documents'][0])):
76
- formatted.append({
77
- "example_number": i + 1,
78
- "question": results['documents'][0][i],
79
- "answer": results['metadatas'][0][i].get('answer', 'N/A'),
80
- "distance": results['distances'][0][i]
81
- })
82
- return SearchResponse(results=formatted)
83
 
84
- # Mount Gradio on FastAPI
85
  app = gr.mount_gradio_app(app, demo, path="/")
 
6
  from fastapi import FastAPI
7
  from pydantic import BaseModel
8
 
9
+ # Extract and load database
10
  DB_PATH = "./medqa_db"
11
  if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
12
+ with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
13
+ z.extractall(".")
 
 
14
 
 
 
15
  client = chromadb.PersistentClient(path=DB_PATH)
16
  collection = client.get_collection("medqa")
 
 
17
  model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
 
18
 
19
  # Search function
20
+ def search(query, num_results=3):
21
+ emb = model.encode(query).tolist()
22
+ return collection.query(query_embeddings=[emb], n_results=int(num_results))
 
23
 
24
+ # Gradio UI
25
+ def ui_search(query, num_results=3):
26
  if not query.strip():
27
+ return "Enter a query"
 
28
  try:
29
+ r = search(query, num_results)
30
+ out = ""
31
+ for i in range(len(r['documents'][0])):
32
+ out += f"\n{'='*60}\nExample {i+1}\n{'='*60}\n"
33
+ out += r['documents'][0][i] + f"\n\nAnswer: {r['metadatas'][0][i].get('answer', 'N/A')}\n"
34
+ out += f"Similarity: {1 - r['distances'][0][i]:.3f}\n"
35
+ return out
 
36
  except Exception as e:
37
+ return f"Error: {e}"
38
 
39
  demo = gr.Interface(
40
+ fn=ui_search,
41
  inputs=[
42
+ gr.Textbox(label="Medical Query", placeholder="e.g., hyponatremia"),
43
+ gr.Slider(1, 5, value=3, step=1, label="Results")
44
  ],
45
  outputs=gr.Textbox(label="Similar USMLE Questions", lines=20),
46
  title="MedQA Search",
 
47
  examples=[["hyponatremia", 3], ["myocardial infarction", 2]]
48
  )
49
 
50
+ # FastAPI
51
  app = FastAPI()
52
 
53
  class SearchRequest(BaseModel):
54
  query: str
55
  num_results: int = 3
56
 
 
 
 
57
  @app.post("/search_medqa")
58
+ def api_search(req: SearchRequest):
59
+ r = search(req.query, req.num_results)
60
+ return {"results": [{
61
+ "example_number": i+1,
62
+ "question": r['documents'][0][i],
63
+ "answer": r['metadatas'][0][i].get('answer', 'N/A'),
64
+ "distance": r['distances'][0][i]
65
+ } for i in range(len(r['documents'][0]))]}
 
 
 
66
 
 
67
  app = gr.mount_gradio_app(app, demo, path="/")