# NurseLex-Match / app.py
# Uploaded by NurseCitizenDeveloper.
# Commit c70a242 (verified): "Upload app.py with huggingface_hub"
import os
import base64
import pickle
import numpy as np
import gradio as gr
from perplexity import Perplexity
import json
# ─── Configuration ───────────────────────────────────────────────
# Embedding model requested from the Perplexity embeddings endpoint.
MODEL_NAME = "pplx-embed-v1-0.6b"
# Pickled vector database, resolved relative to the working directory.
DB_FILE = "case_embeddings.pkl"

# ─── Load Vector Database ────────────────────────────────────────
print(f"Loading vector database from {DB_FILE}...")
if not os.path.exists(DB_FILE):
    # Fail at startup with an actionable message instead of at first query.
    raise FileNotFoundError(
        f"{DB_FILE} not found. Please run create_case_db.py first! "
        "(Note: This takes ~12 hours for the full dataset)."
    )

# SECURITY: unpickling can execute arbitrary code stored in the file. Only
# load a DB_FILE that was produced by a trusted build of create_case_db.py.
with open(DB_FILE, 'rb') as f:
    database = pickle.load(f)

# The database is a dict with four entries:
#   "cases"       - clinical case records (looked up by "source_idx")
#   "legislation" - legislation section records (looked up by "source_idx")
#   "mapping"     - one dict per embedding row with "type" and "source_idx"
#   "embeddings"  - 2-D array with one embedding per row
cases = database["cases"]
legislation = database["legislation"]
mapping = database["mapping"]
corpus_embeddings = database["embeddings"]

# Normalize corpus embeddings for cosine similarity. The floor on the norm
# keeps an all-zero row from causing a division by zero.
corpus_norms = np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)
corpus_embeddings_normalized = corpus_embeddings / np.maximum(corpus_norms, 1e-8)

# Pre-calculate the row indices of each record type so a search only scores
# the rows it can return. Position k in case_embeddings corresponds to global
# row case_indices[k] (likewise for legislation).
case_indices = [i for i, m in enumerate(mapping) if m["type"] == "case"]
leg_indices = [i for i, m in enumerate(mapping) if m["type"] == "legislation"]
case_embeddings = corpus_embeddings_normalized[case_indices]
leg_embeddings = corpus_embeddings_normalized[leg_indices]

print(f"✅ Loaded {len(cases)} clinical cases and {len(legislation)} legislation sections.")
def decode_embedding(b64_string):
    """Convert a base64 payload of signed 8-bit values into a float32 vector.

    Each decoded byte is read as one int8 component, then the whole vector is
    widened to float32 so it can be normalized and used in dot products.
    """
    raw_bytes = base64.b64decode(b64_string)
    quantized = np.frombuffer(raw_bytes, dtype=np.int8)
    return quantized.astype(np.float32)
def format_nursing_reasoning(reasoning):
    """Render a nursing-reasoning record as markdown.

    A non-dict value is returned as ``str(reasoning)``. For a dict, each known
    section that is present becomes a bold heading followed by its content and
    a blank line; sections are emitted in a fixed order and any other keys are
    ignored.
    """
    if not isinstance(reasoning, dict):
        return str(reasoning)

    # Display order of the recognised sections.
    section_titles = (
        "Differential Diagnosis",
        "Diagnostic Tests",
        "Management Plan",
        "Underlying Mechanism",
        "Relevant Medications",
    )
    return "".join(
        f"**{title}:**\n{reasoning[title]}\n\n"
        for title in section_titles
        if title in reasoning
    )
def _top_match_indices(similarities, top_k):
    """Return the positions of the ``top_k`` highest scores, best first."""
    # The UI slider may hand over a float, and slice bounds must be ints.
    # Clamping to >= 1 also avoids the ``[-0:]`` pitfall, where a limit of
    # zero would silently select every row instead of none.
    limit = max(1, int(top_k))
    best_first = np.argsort(similarities)[::-1]
    return best_first[:limit]


def _format_cost(response):
    """Return the API cost line for ``response``, or "" if it is unavailable."""
    # Cost reporting is best-effort: a response without usage/cost details
    # must not turn an otherwise successful search into an error.
    usage = getattr(response, "usage", None)
    cost = getattr(usage, "cost", None)
    total = getattr(cost, "total_cost", None)
    if total is None:
        return ""
    try:
        return f"API Cost: ${float(total):.6f}"
    except (TypeError, ValueError):
        return ""


def _render_case_result(rank, score, case_item):
    """Render one matched clinical case as a markdown section."""
    vignette = case_item.get("original_vignette", "")
    reasoning = case_item.get("nursing_reasoning", {})
    case_id = case_item.get("id", "Unknown")
    return (
        f"#### Result {rank} — Relevance: {score:.3f} (Case #{case_id})\n"
        f"**Patient Presentation:**\n> {vignette}\n\n"
        "<details><summary><b>View Nursing Reasoning</b></summary>\n\n"
        f"{format_nursing_reasoning(reasoning)}\n</details>\n\n---\n"
    )


def _render_legislation_result(rank, score, leg_item):
    """Render one matched legislation section as a markdown section."""
    title = leg_item.get("title", "")
    leg_id = leg_item.get("legislation_id", "")
    text = leg_item.get("text", "")
    return (
        f"#### Result {rank} — Relevance: {score:.3f} (Act: {leg_id})\n"
        f"⚖️ **{title}**\n\n"
        f"{text}\n\n---\n"
    )


def perform_search(api_key, query, top_k, search_type):
    """Embed ``query`` with the Perplexity API and rank the stored records.

    Args:
        api_key: Perplexity API key supplied by the user.
        query: Free-text clinical presentation or legal question.
        top_k: Maximum number of results; coerced to an int of at least 1.
        search_type: ``"case"`` searches clinical cases; any other value
            searches legislation.

    Returns:
        A ``(markdown, cost)`` tuple. ``markdown`` holds the rendered results
        or a user-facing message; ``cost`` is the API cost line, or ``""``
        when it is unavailable or the search did not run.

    Failures while embedding or ranking are reported in ``markdown`` rather
    than raised, so the UI always receives something to display.
    """
    if not api_key or not api_key.strip():
        return "⚠️ **Please enter your Perplexity API Key** above.", ""
    if not query or not query.strip():
        return "Please enter a clinical presentation or legal query.", ""

    try:
        # Embed the query.
        client = Perplexity(api_key=api_key.strip())
        response = client.embeddings.create(
            input=[query],
            model=MODEL_NAME,
        )
        query_embedding = decode_embedding(response.data[0].embedding)

        # Unit-normalize so the dot product with the pre-normalized corpus
        # rows is a cosine similarity; the floor guards a zero vector.
        query_norm = np.linalg.norm(query_embedding)
        query_unit = query_embedding / max(query_norm, 1e-8)

        # Pick the subset of the corpus, its global row indices, the source
        # records and the renderer that belong to the requested search type.
        if search_type == "case":
            embeddings, row_indices = case_embeddings, case_indices
            records, render = cases, _render_case_result
        else:
            embeddings, row_indices = leg_embeddings, leg_indices
            records, render = legislation, _render_legislation_result

        similarities = np.dot(embeddings, query_unit)

        sections = [f"### 🔍 Results for: *\"{query}\"*\n\n---\n"]
        ranked = _top_match_indices(similarities, top_k)
        for rank, local_idx in enumerate(ranked, start=1):
            # local_idx indexes the per-type subset; translate it to the
            # global mapping row, then to the record in its source list.
            global_idx = row_indices[local_idx]
            record = records[mapping[global_idx]["source_idx"]]
            sections.append(render(rank, similarities[local_idx], record))

        return "".join(sections), _format_cost(response)
    except Exception as exc:  # UI boundary: surface the problem as a message.
        error_msg = str(exc)
        if "401" in error_msg or "auth" in error_msg.lower():
            return "❌ **Invalid API Key.** Check your key and try again.", ""
        return f"❌ Error: {error_msg}", ""
def search_cases(api_key, query, top_k):
    """Gradio callback that searches the clinical-case records.

    Forwards its arguments to ``perform_search`` with the search type fixed
    to ``"case"`` and returns that call's result unchanged.
    """
    result = perform_search(api_key, query, top_k, "case")
    return result
def search_legislation(api_key, query, top_k):
    """Gradio callback that searches the legislation records.

    Forwards its arguments to ``perform_search`` with the search type fixed
    to ``"legislation"`` and returns that call's result unchanged.
    """
    result = perform_search(api_key, query, top_k, "legislation")
    return result
# ─── Gradio UI ───────────────────────────────────────────────────
# Two-tab interface (clinical cases / legislation) sharing one API-key field
# and one cost read-out. Each tab has a query box, a result-count slider, a
# search button, example queries and a markdown results pane.
with gr.Blocks(title="NurseLex-Match") as app:
    gr.Markdown(
        "\n"
        "# 🩺 NurseLex-Match\n"
        "### Clinical Case Similarity & Legal Lookup\n"
        "*Powered by Perplexity Embeddings (`pplx-embed-v1-0.6b`)*\n"
        "Retrieve similar historical cases from the **NurseReason-Dataset** or search UK nursing law in **NurseLex**.\n"
    )

    with gr.Accordion("🔑 API Key (BYOK — Bring Your Own Key)", open=True):
        gr.Markdown(
            "Your key is **never stored** — it is used only for this session to query the embeddings API. "
        )
        api_key_input = gr.Textbox(
            label="Perplexity API Key",
            placeholder="pplx-...",
            type="password",
            lines=1,
        )

    # NOTE(review): placed outside the accordion so the cost stays visible
    # when the key panel is collapsed — confirm intended placement.
    api_cost = gr.Markdown("*API Cost: $0.000000*")

    with gr.Tabs():
        # TAB 1: Clinical Cases
        with gr.TabItem("🏥 Clinical Case Matcher"):
            with gr.Row():
                with gr.Column(scale=3):
                    case_search_input = gr.Textbox(
                        label="Describe the patient presentation:",
                        placeholder="e.g., 72-year-old presenting with acute confusion and suspected UTI...",
                        lines=3,
                    )
                with gr.Column(scale=1):
                    case_top_k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Results")
                    case_search_btn = gr.Button("🔍 Find Similar Cases", variant="primary")
            gr.Examples(
                examples=[
                    "Patient is a 45-year-old female complaining of sharp left lower quadrant abdominal pain radiating to the back.",
                    "Elderly patient with a history of heart failure presenting with increased shortness of breath and pitting edema.",
                    "Post-operative patient day 2 showing signs of infection at the wound site with low-grade fever.",
                ],
                inputs=case_search_input,
            )
            case_results = gr.Markdown("*Similar cases will appear here...*")

        # TAB 2: Nursing Legislation
        with gr.TabItem("⚖️ Legal Lookup"):
            with gr.Row():
                with gr.Column(scale=3):
                    leg_search_input = gr.Textbox(
                        label="Describe the legal context or policy question:",
                        placeholder="e.g., What are the rules regarding compulsory admission under the Mental Health Act?",
                        lines=3,
                    )
                with gr.Column(scale=1):
                    leg_top_k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Results")
                    leg_search_btn = gr.Button("🔍 Search Legislation", variant="primary")
            gr.Examples(
                examples=[
                    "When can a nurse legally refuse to participate in a procedure?",
                    "What is the required staffing ratio for an intensive care unit?",
                    "Regulations regarding the administration of controlled drugs.",
                ],
                inputs=leg_search_input,
            )
            leg_results = gr.Markdown("*Relevant legislation will appear here...*")

    # Wire up search events. In each tab the button click and the textbox
    # submit event run the same search, so both are bound to one
    # configuration instead of repeating it.
    for bind_event in (case_search_btn.click, case_search_input.submit):
        bind_event(
            fn=search_cases,
            inputs=[api_key_input, case_search_input, case_top_k],
            outputs=[case_results, api_cost],
        )
    for bind_event in (leg_search_btn.click, leg_search_input.submit):
        bind_event(
            fn=search_legislation,
            inputs=[api_key_input, leg_search_input, leg_top_k],
            outputs=[leg_results, api_cost],
        )
if __name__ == "__main__":
    # Script entry point: announce startup, then serve the Gradio app.
    print("Starting NurseLex-Match...")
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    app.launch(**launch_options)