import os
import json
import re

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# --- 1. CONFIGURATION ---
REPO_ID = "st192011/llama-pharma-assistant"   # HF repo hosting the fine-tuned GGUF model
GGUF_FILENAME = "pharma_model_q4.gguf"        # Q4-quantized model file for CPU inference
DB_FILE = "ultimate_safety_vault.json"        # verified drug-safety database ("vault")

# --- 2. LOAD RESOURCES ---
# Load Database (the "Source of Truth" the model is gated against).
# On any load failure we fall back to an empty dict so the UI still starts;
# lookups will then simply report "not found".
try:
    with open(DB_FILE, "r", encoding="utf-8") as f:
        db = json.load(f)
    print(f"✅ Database loaded: {len(db)} generic entries.")
except (OSError, json.JSONDecodeError):
    # Narrowed from a bare `except:` — only file/parse errors are expected here.
    print("❌ DB not found.")
    db = {}

# Load Model. `llm` stays None on failure; pipeline_generator checks for that.
print("⏳ Loading Model...")
try:
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=GGUF_FILENAME,
        token=os.environ.get("HF_TOKEN"),
    )
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=2,
        verbose=False,
    )
    print("✅ Model Ready.")
except Exception as e:
    print(f"❌ Model Error: {e}")
    llm = None


# --- 3. CORE LOGIC ---
def find_drug_in_db(brand_query):
    """Resolve a brand/generic name against the vault.

    Matches case-insensitively against the generic keys first, then against
    each entry's ``eu_brands`` synonym list.

    Returns:
        (generic_key, entry_dict) on a hit, otherwise (None, None).
    """
    if not brand_query:
        return None, None
    brand_query = brand_query.strip().upper()
    # 1. Check Generic Keys (vault keys are stored uppercase)
    if brand_query in db:
        return brand_query, db[brand_query]
    # 2. Check Synonyms/Brands
    for generic, data in db.items():
        if any(brand_query == b.upper() for b in data.get("eu_brands", [])):
            return generic, data
    return None, None


def pipeline_generator(user_query):
    """3-stage gated pipeline: parse intent -> verify against vault -> summarize.

    Yields 3 outputs on every step so the UI stays consistent:
    [Parser_JSON, DB_Code, Final_Markdown]
    """
    if not llm:
        yield {}, "Error: Model not loaded.", "System Error."
        return

    # --- STEP 1: PARSING ---
    # Yield initial state so the UI shows progress immediately.
    yield {"status": "Parsing..."}, "Waiting for Parser...", "Waiting..."

    parser_messages = [
        # Typo fixed: "JOSN" -> "JSON" (the instruction is sent to the model verbatim).
        {"role": "user", "content": f"SYSTEM: You are a semantic parser. Extract intent and brand into JSON and JSON only.\nUSER: {user_query}"}
    ]
    # Non-streaming call for JSON extraction (short, low temperature for determinism).
    parser_out = llm.create_chat_completion(messages=parser_messages, max_tokens=60, temperature=0.01)
    raw_json = parser_out['choices'][0]['message']['content']

    # Attempt to pull the first {...} span out of the raw completion.
    parsed_obj = {}
    try:
        match = re.search(r'\{.*\}', raw_json, re.DOTALL)
        if match:
            parsed_obj = json.loads(match.group(0))
        else:
            yield {"raw": raw_json, "error": "No JSON found"}, "Pipeline Stopped.", "⚠️ I couldn't identify the drug name."
            return
    except json.JSONDecodeError:
        # Narrowed from a bare `except:` — only a decode failure is expected here.
        yield {"raw": raw_json, "error": "JSON Decode Error"}, "Pipeline Stopped.", "⚠️ Internal Parsing Error."
        return

    # --- STEP 2: DATABASE (the guardrail) ---
    brand = parsed_obj.get("brand", "UNKNOWN")
    intent = parsed_obj.get("intent", "UNKNOWN")

    # Yield Parser Result immediately
    yield parsed_obj, f"Searching Vault for '{brand}'...", "..."

    generic_name, drug_data = find_drug_in_db(brand)

    if not drug_data:
        # Not Found Logic: terminate — the model never answers unverified drugs.
        not_found_msg = f"ERROR: '{brand}' not found in Verified Vault."
        final_msg = f"🛑 **Not Found:** I cannot provide information on **{brand}** because it is not in the verified safety database (Demo Limit: 10 Substances)."
        yield parsed_obj, not_found_msg, final_msg
        return

    # Extract the context slice matching the parsed intent.
    context_text = "N/A"
    if intent == "GET_WARNING":
        context_text = drug_data['official_label'].get('boxed_warning', 'No boxed warning.')
    elif intent == "CHECK_CONTRA":
        context_text = drug_data['official_label'].get('contraindications', 'No contraindications.')
    elif intent == "GET_SIDE_EFFECTS":
        # Top 10 reported side effects with report counts.
        se = [f"{i['term']} ({i['count']})" for i in drug_data.get('side_effects', [])[:10]]
        context_text = ", ".join(se)
    elif intent == "CHECK_RELIABILITY":
        s = drug_data.get('reporter_stats', {})
        context_text = f"Clinical Reports: {s.get('clinical_reports')}\nConsumer Reports: {s.get('consumer_reports')}"
    elif intent == "GET_CLASS":
        context_text = drug_data.get('generic', 'Unknown')

    # Format Database Output for Display
    db_display = f"GENERIC_KEY: {generic_name}\nINTENT_MATCH: {intent}\n\n[EXTRACTED_DATA]:\n{context_text}"

    # Yield DB Result immediately
    yield parsed_obj, db_display, "Generating Summary..."

    # --- STEP 3: GENERATION (Streaming) ---
    input_str = f"Databank found for {brand} (Generic: {generic_name}): {context_text}"
    assistant_messages = [
        {"role": "user", "content": f"SYSTEM: You are a medical assistant. Answer based on context.\nUSER: {user_query}\nINPUT: {input_str}"}
    ]

    stream = llm.create_chat_completion(
        messages=assistant_messages,
        max_tokens=250,
        temperature=0.01,
        stream=True,
    )

    partial_ans = ""
    for chunk in stream:
        delta = chunk['choices'][0]['delta']
        if 'content' in delta:
            partial_ans += delta['content']
            # Yield ALL 3 outputs every time to keep UI consistent
            yield parsed_obj, db_display, partial_ans
# --- 4. UI LAYOUT ---
# Three-panel layout mirroring the pipeline stages: parser JSON, database
# guardrail output, and the final streamed summary.
with gr.Blocks(theme=gr.themes.Soft(), title="PharmaVault Research Demo") as demo:
    gr.Markdown("# 🛡️ PharmaVault AI: Verified Drug Safety Reasoning Engine")

    # --- RESEARCH INFO ACCORDION ---
    # Collapsed technical report; purely informational, no event wiring.
    with gr.Accordion("📚 Technical Report & Database Coverage", open=False):
        gr.Markdown(
            """
# 🛡️ PharmaVault AI: Verified Drug Safety Reasoning Engine

**Author:** st192011 | **Framework:** Llama-3.1-8B + Unsloth + GGUF

### 1. Abstract: Solving Parametric Hallucinations
Medical LLMs often suffer from **"parametric hallucinations,"** confidently inventing drug facts based on outdated weights. PharmaVault AI addresses this by **decoupling Reasoning from Knowledge**. The model does not "know" the facts; it acts strictly as a **Reasoning Agent** that parses user intent, queries a verified 4-source database (RxNorm, openFDA, SIDER, EMA), and summarizes the retrieved "Source of Truth."

### 2. Methodology & Training
To achieve this multi-task capability, the model was fine-tuned on a balanced dataset of **10,000 samples**:
* **Task A (Semantic Parser - 5k):** Instruction-tuning to convert natural language into JSON keys (e.g., `{"brand": "WARFARIN"}`).
* **Task B (Verified Responder - 5k):** Training the model to synthesize context while ignoring internal pre-existing knowledge.
* **Specs:** Trained via Unsloth (LoRA, Rank 16) and exported to **GGUF (Q4_K_M)** for efficient CPU inference.

### 3. The 3-Stage Gated Pipeline
The application implements a strict logic flow to ensure safety:
1. 🧠 **Stage 1 (The Brain):** The Parser extracts the `Brand` and `Intent`.
2. 🔒 **Stage 2 (The Guardrail):** A Python-based lookup verifies the Brand against the Vault. **If not found, the process terminates.**
3. 👩‍⚕️ **Stage 3 (The Voice):** If verified, the model summarizes the data.

### 4. Logic Branching
| Situation | Trigger | Outcome |
| :--- | :--- | :--- |
| **Success** | Brand found in vault | AI Summary + Raw JSON displayed. |
| **Partial Fail** | Brand parsed but not in DB | Stops. Returns "Substance not covered". |
| **Critical Fail** | Brand not identified | Stops. Returns "Please rephrase". |

### 5. Current Database Coverage
The demonstration database contains **10 generic substances**, but via the synonym mapping layer, it covers over **2,500 global brands**:
* **Metformin:** 869 brands
* **Atorvastatin:** 563 brands
* **Rivaroxaban:** 336 brands
* **Lisinopril:** 253 brands
* **Insulin:** 187 brands
* **Methotrexate:** 126 brands
* **Amiodarone:** 86 brands
* **Phenytoin:** 54 brands
* **Digoxin:** 52 brands
* **Warfarin:** 27 brands
"""
        )

    # --- MAIN INTERFACE ---
    with gr.Row():
        input_box = gr.Textbox(
            label="Patient Query",
            placeholder="e.g., What is the black box warning for Warfarin?",
            lines=2,
            scale=4
        )
        btn = gr.Button("Analyze", variant="primary", scale=1)

    # ROW 1: INTERMEDIATE STEPS
    # Left column: raw parser JSON. Right column: verified DB context.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🧠 1. Semantic Parser Output")
            out_parser = gr.JSON(label="Extracted Intent")
        with gr.Column(scale=2):
            gr.Markdown("### 🔒 2. Database Guardrail Output")
            out_db = gr.Code(label="Verified Content (Source of Truth)", language="yaml", interactive=False)

    # ROW 2: FINAL RESULT
    gr.Markdown("### 👩‍⚕️ 3. Final Medical Summary")
    out_final = gr.Markdown(value="Waiting for input...")

    # --- EVENT HANDLER ---
    # pipeline_generator is a generator, so Gradio streams each yield into
    # the three outputs in order.
    btn.click(
        fn=pipeline_generator,
        inputs=[input_box],
        outputs=[out_parser, out_db, out_final]
    )

    # Example Cache
    # Includes deliberate "ghost drug" / out-of-vault names to demo the guardrail.
    gr.Examples(
        examples=[
            ["What is the black box warning for Warfarin?"],
            ["Is Zentridol an antibiotic? (Ghost Drug Test)"],
            ["Give me the side effects of Nafordyl."],
            ["Is Metotrexato Accord safe for children?"]
        ],
        inputs=input_box
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)