"""NurseLex-Match: semantic search over clinical cases and UK nursing legislation.

Loads a precomputed embedding database (produced by ``create_case_db.py``),
embeds user queries with the Perplexity embeddings API, and ranks stored items
by cosine similarity inside a Gradio UI.
"""

import base64
import json
import os
import pickle

import gradio as gr
import numpy as np
from perplexity import Perplexity

# ─── Configuration ───────────────────────────────────────────────
MODEL_NAME = "pplx-embed-v1-0.6b"
DB_FILE = "case_embeddings.pkl"

# ─── Load Vector Database ────────────────────────────────────────
print(f"Loading vector database from {DB_FILE}...")
if not os.path.exists(DB_FILE):
    raise FileNotFoundError(
        f"{DB_FILE} not found. Please run create_case_db.py first! "
        "(Note: This takes ~12 hours for the full dataset)."
    )

# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only load a database you generated yourself; never one from an untrusted source.
with open(DB_FILE, 'rb') as f:
    database = pickle.load(f)

cases = database["cases"]
legislation = database["legislation"]
mapping = database["mapping"]  # one entry per embedding row: {"type", "source_idx", ...}
corpus_embeddings = database["embeddings"]

# Normalize corpus embeddings for cosine similarity (guard against zero rows).
corpus_norms = np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)
corpus_embeddings_normalized = corpus_embeddings / np.maximum(corpus_norms, 1e-8)

# Pre-calculate row indices for each item type so each search only scores
# the relevant slice of the corpus.
case_indices = [i for i, m in enumerate(mapping) if m["type"] == "case"]
leg_indices = [i for i, m in enumerate(mapping) if m["type"] == "legislation"]
case_embeddings = corpus_embeddings_normalized[case_indices]
leg_embeddings = corpus_embeddings_normalized[leg_indices]

print(f"✅ Loaded {len(cases)} clinical cases and {len(legislation)} legislation sections.")


def decode_embedding(b64_string):
    """Decode a base64-encoded int8 embedding into a float32 vector.

    Args:
        b64_string: Base64 text whose decoded bytes are read as signed 8-bit
            integers, one per dimension.

    Returns:
        A 1-D ``np.float32`` array with one element per decoded byte.
    """
    raw = base64.b64decode(b64_string)
    return np.frombuffer(raw, dtype=np.int8).astype(np.float32)


# Sections rendered by format_nursing_reasoning, in display order.
_REASONING_SECTIONS = (
    "Differential Diagnosis",
    "Diagnostic Tests",
    "Management Plan",
    "Underlying Mechanism",
    "Relevant Medications",
)


def format_nursing_reasoning(reasoning):
    """Render a ``nursing_reasoning`` record as Markdown.

    Args:
        reasoning: Normally a dict keyed by section title. Any non-dict value
            is rendered with ``str()``.

    Returns:
        One bold-headed paragraph for each section in ``_REASONING_SECTIONS``
        that is present in *reasoning*, in that fixed order. Unknown keys are
        ignored. Returns an empty string when no known section is present.
    """
    if not isinstance(reasoning, dict):
        return str(reasoning)
    return "".join(
        f"**{section}:**\n{reasoning[section]}\n\n"
        for section in _REASONING_SECTIONS
        if section in reasoning
    )


def _rank(embeddings, query_vector, top_k):
    """Score *embeddings* against *query_vector* and pick the best rows.

    Rows of *embeddings* and *query_vector* are expected to be L2-normalized
    already, so the dot product equals cosine similarity.

    Args:
        embeddings: 2-D array of candidate vectors (one per row).
        query_vector: 1-D query vector.
        top_k: Number of rows to return. It is coerced to ``int`` in case the
            UI delivers a float. Values <= 0 yield no rows.

    Returns:
        ``(indices, similarities)``. *indices* holds row positions sorted
        best-first. *similarities* holds the score for every row.
    """
    similarities = np.dot(embeddings, query_vector)
    k = max(int(top_k), 0)
    # argsort is ascending: reverse for best-first, then take a prefix.
    # Taking a prefix avoids the ``[-0:]`` pitfall, which would select every row.
    order = np.argsort(similarities)[::-1][:k]
    return order, similarities


def _format_cost(response):
    """Return an ``"API Cost: $x"`` label for *response*, or ``""`` if unavailable.

    Cost metadata may be absent from a response. Its absence must not discard
    search results that were already computed.
    """
    usage = getattr(response, "usage", None)
    if not usage:
        return ""
    total = getattr(getattr(usage, "cost", None), "total_cost", None)
    if total is None:
        return ""
    try:
        return f"API Cost: ${total:.6f}"
    except (TypeError, ValueError):
        return ""


def _format_case_results(top_indices, similarities):
    """Render ranked clinical-case hits as Markdown with collapsible reasoning."""
    parts = []
    for rank, local_idx in enumerate(top_indices, start=1):
        score = similarities[local_idx]
        global_idx = case_indices[local_idx]
        case_item = cases[mapping[global_idx]["source_idx"]]
        vignette = case_item.get("original_vignette", "")
        reasoning = case_item.get("nursing_reasoning", {})
        case_id = case_item.get("id", "Unknown")
        parts.append(
            f"#### Result {rank} — Relevance: {score:.3f} (Case #{case_id})\n"
            f"**Patient Presentation:**\n> {vignette}\n\n"
            "<details><summary>View Nursing Reasoning</summary>\n\n"
            f"{format_nursing_reasoning(reasoning)}\n</details>\n\n---\n"
        )
    return "".join(parts)


def _format_legislation_results(top_indices, similarities):
    """Render ranked legislation hits as Markdown."""
    parts = []
    for rank, local_idx in enumerate(top_indices, start=1):
        score = similarities[local_idx]
        global_idx = leg_indices[local_idx]
        leg_item = legislation[mapping[global_idx]["source_idx"]]
        title = leg_item.get("title", "")
        leg_id = leg_item.get("legislation_id", "")
        text = leg_item.get("text", "")
        parts.append(
            f"#### Result {rank} — Relevance: {score:.3f} (Act: {leg_id})\n"
            f"⚖️ **{title}**\n\n"
            f"{text}\n\n---\n"
        )
    return "".join(parts)


def perform_search(api_key, query, top_k, search_type):
    """Embed *query* with the Perplexity API and return the closest stored items.

    Args:
        api_key: Perplexity API key supplied by the user. It is used to build
            a client for this call only and is not stored by this function.
        query: Free-text clinical presentation or legal question.
        top_k: Number of results to show (coerced to ``int``).
        search_type: ``"case"`` searches clinical cases. Any other value
            searches legislation.

    Returns:
        A tuple ``(results_markdown, cost_label)``. On validation or API
        failure, the first element is a user-facing message and the second
        is ``""``.
    """
    if not api_key or not api_key.strip():
        return "⚠️ **Please enter your Perplexity API Key** above.", ""
    if not query or not query.strip():
        return "Please enter a clinical presentation or legal query.", ""

    try:
        # Embed the query
        client = Perplexity(api_key=api_key.strip())
        response = client.embeddings.create(input=[query], model=MODEL_NAME)
        query_embedding = decode_embedding(response.data[0].embedding)

        # Normalize the query so the dot product in _rank is cosine similarity.
        query_norm = np.linalg.norm(query_embedding)
        query_embedding_normalized = query_embedding / max(query_norm, 1e-8)

        output = f"### 🔍 Results for: *\"{query}\"*\n\n---\n"
        cost_str = _format_cost(response)

        if search_type == "case":
            top_indices, similarities = _rank(case_embeddings, query_embedding_normalized, top_k)
            output += _format_case_results(top_indices, similarities)
        else:
            top_indices, similarities = _rank(leg_embeddings, query_embedding_normalized, top_k)
            output += _format_legislation_results(top_indices, similarities)

        return output, cost_str

    except Exception as e:  # UI boundary: surface any failure as a message
        error_msg = str(e)
        if "401" in error_msg or "auth" in error_msg.lower():
            return "❌ **Invalid API Key.** Check your key and try again.", ""
        return f"❌ Error: {error_msg}", ""


def search_cases(api_key, query, top_k):
    """Gradio handler: search clinical cases. See ``perform_search``."""
    return perform_search(api_key, query, top_k, search_type="case")


def search_legislation(api_key, query, top_k):
    """Gradio handler: search legislation. See ``perform_search``."""
    return perform_search(api_key, query, top_k, search_type="legislation")


# ─── Gradio UI ───────────────────────────────────────────────────
with gr.Blocks(title="NurseLex-Match") as app:
    gr.Markdown(
        """
        # 🩺 NurseLex-Match
        ### Clinical Case Similarity & Legal Lookup
        *Powered by Perplexity Embeddings (`pplx-embed-v1-0.6b`)*

        Retrieve similar historical cases from the **NurseReason-Dataset** or search UK nursing law in **NurseLex**.
        """
    )

    with gr.Accordion("🔑 API Key (BYOK — Bring Your Own Key)", open=True):
        gr.Markdown(
            "Your key is **never stored** — it is used only for this session to query the embeddings API. "
        )
        api_key_input = gr.Textbox(
            label="Perplexity API Key",
            placeholder="pplx-...",
            type="password",
            lines=1,
        )
        api_cost = gr.Markdown("*API Cost: $0.000000*")

    with gr.Tabs():
        # TAB 1: Clinical Cases
        with gr.TabItem("🏥 Clinical Case Matcher"):
            with gr.Row():
                with gr.Column(scale=3):
                    case_search_input = gr.Textbox(
                        label="Describe the patient presentation:",
                        placeholder="e.g., 72-year-old presenting with acute confusion and suspected UTI...",
                        lines=3,
                    )
                with gr.Column(scale=1):
                    case_top_k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Results")
                    case_search_btn = gr.Button("🔍 Find Similar Cases", variant="primary")

            gr.Examples(
                examples=[
                    "Patient is a 45-year-old female complaining of sharp left lower quadrant abdominal pain radiating to the back.",
                    "Elderly patient with a history of heart failure presenting with increased shortness of breath and pitting edema.",
                    "Post-operative patient day 2 showing signs of infection at the wound site with low-grade fever.",
                ],
                inputs=case_search_input,
            )
            case_results = gr.Markdown("*Similar cases will appear here...*")

        # TAB 2: Nursing Legislation
        with gr.TabItem("⚖️ Legal Lookup"):
            with gr.Row():
                with gr.Column(scale=3):
                    leg_search_input = gr.Textbox(
                        label="Describe the legal context or policy question:",
                        placeholder="e.g., What are the rules regarding compulsory admission under the Mental Health Act?",
                        lines=3,
                    )
                with gr.Column(scale=1):
                    leg_top_k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Results")
                    leg_search_btn = gr.Button("🔍 Search Legislation", variant="primary")

            gr.Examples(
                examples=[
                    "When can a nurse legally refuse to participate in a procedure?",
                    "What is the required staffing ratio for an intensive care unit?",
                    "Regulations regarding the administration of controlled drugs.",
                ],
                inputs=leg_search_input,
            )
            leg_results = gr.Markdown("*Relevant legislation will appear here...*")

    # Wire up search events: both the button and pressing Enter trigger a search.
    case_search_btn.click(
        fn=search_cases,
        inputs=[api_key_input, case_search_input, case_top_k],
        outputs=[case_results, api_cost],
    )
    case_search_input.submit(
        fn=search_cases,
        inputs=[api_key_input, case_search_input, case_top_k],
        outputs=[case_results, api_cost],
    )
    leg_search_btn.click(
        fn=search_legislation,
        inputs=[api_key_input, leg_search_input, leg_top_k],
        outputs=[leg_results, api_cost],
    )
    leg_search_input.submit(
        fn=search_legislation,
        inputs=[api_key_input, leg_search_input, leg_top_k],
        outputs=[leg_results, api_cost],
    )

if __name__ == "__main__":
    print("Starting NurseLex-Match...")
    app.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)