Spaces:
Running
Running
| import os | |
| import json | |
| import re | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from llama_cpp import Llama | |
# --- 1. CONFIGURATION ---
REPO_ID = "st192011/llama-pharma-assistant"  # HF Hub repo hosting the fine-tuned model
GGUF_FILENAME = "pharma_model_q4.gguf"       # quantized GGUF file inside that repo
DB_FILE = "ultimate_safety_vault.json"       # local verified drug-safety database

# --- 2. LOAD RESOURCES ---
# Load Database. On failure we fall back to an empty dict so the app can
# still start; every lookup will then simply report "not found".
try:
    with open(DB_FILE, 'r') as f:
        db = json.load(f)
    print(f"β Database loaded: {len(db)} generic entries.")
except (OSError, json.JSONDecodeError):
    # Narrowed from a bare `except:` so unrelated bugs (NameError, etc.)
    # are no longer silently swallowed as "DB not found".
    print("β DB not found.")
    db = {}
# Load Model
# Download the GGUF weights from the Hub, then open them with llama-cpp.
# Any failure leaves `llm` as None; the pipeline checks for that at runtime.
print("β³ Loading Model...")
llm = None
try:
    gguf_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=GGUF_FILENAME,
        token=os.environ.get("HF_TOKEN"),  # private-repo auth, if set
    )
    llm = Llama(model_path=gguf_path, n_ctx=2048, n_threads=2, verbose=False)
    print("β Model Ready.")
except Exception as e:
    print(f"β Model Error: {e}")
    llm = None
| # --- 3. CORE LOGIC --- | |
def find_drug_in_db(brand_query, database=None):
    """Resolve a (possibly branded) drug name against the safety vault.

    Args:
        brand_query: User-supplied drug name; matched case-insensitively
            against generic keys first, then against each entry's
            ``eu_brands`` synonym list.
        database: Optional mapping of generic name -> data dict. Defaults
            to the module-level ``db`` vault, so existing callers are
            unaffected; passing one explicitly makes the lookup testable.

    Returns:
        ``(generic_key, data_dict)`` on a match, ``(None, None)`` otherwise.
    """
    if database is None:
        database = db
    if not brand_query:
        return None, None
    brand_query = brand_query.strip().upper()
    # 1. Direct hit on a generic key (assumes keys are stored uppercase,
    #    matching the vault's convention).
    if brand_query in database:
        return brand_query, database[brand_query]
    # 2. Fall back to the brand/synonym list attached to each generic.
    for generic, data in database.items():
        if any(brand_query == b.upper() for b in data.get("eu_brands", [])):
            return generic, data
    return None, None
def pipeline_generator(user_query):
    """Run the 3-stage safety pipeline as a Gradio streaming generator.

    Args:
        user_query: Free-text patient question from the UI textbox.

    Yields:
        3-tuples matching the UI outputs: [Parser_JSON, DB_Code, Final_Markdown].
        Intermediate yields keep all three panels updated while the
        pipeline progresses; the final answer is streamed token by token.
    """
    if not llm:
        yield {}, "Error: Model not loaded.", "System Error."
        return
    # --- STEP 1: PARSING ---
    # Yield initial state so the UI shows progress immediately.
    yield {"status": "Parsing..."}, "Waiting for Parser...", "Waiting..."
    # NOTE(review): "JOSN" looks like a typo, but the fine-tuned parser may
    # have been trained on this exact prompt string — confirm against the
    # training data before correcting it.
    parser_messages = [
        {"role": "user", "content": f"SYSTEM: You are a semantic parser. Extract intent and brand into JSON and JOSN only.\nUSER: {user_query}"}
    ]
    # Non-streaming call: the complete JSON is needed before stage 2.
    parser_out = llm.create_chat_completion(messages=parser_messages, max_tokens=60, temperature=0.01)
    raw_json = parser_out['choices'][0]['message']['content']
    # Attempt to extract the first {...} span from the raw model output.
    parsed_obj = {}
    try:
        match = re.search(r'\{.*\}', raw_json, re.DOTALL)
        if match:
            parsed_obj = json.loads(match.group(0))
        else:
            yield {"raw": raw_json, "error": "No JSON found"}, "Pipeline Stopped.", "β οΈ I couldn't identify the drug name."
            return
    except (json.JSONDecodeError, TypeError):
        # Narrowed from a bare `except:`. TypeError covers raw_json being
        # None (missing content); decode errors cover malformed JSON.
        yield {"raw": raw_json, "error": "JSON Decode Error"}, "Pipeline Stopped.", "β οΈ Internal Parsing Error."
        return
    # --- STEP 2: DATABASE ---
    brand = parsed_obj.get("brand", "UNKNOWN")
    intent = parsed_obj.get("intent", "UNKNOWN")
    # Yield Parser Result immediately
    yield parsed_obj, f"Searching Vault for '{brand}'...", "..."
    generic_name, drug_data = find_drug_in_db(brand)
    if not drug_data:
        # Guardrail: substance not verified -> terminate before generation.
        not_found_msg = f"ERROR: '{brand}' not found in Verified Vault."
        final_msg = f"π **Not Found:** I cannot provide information on **{brand}** because it is not in the verified safety database (Demo Limit: 10 Substances)."
        yield parsed_obj, not_found_msg, final_msg
        return
    # Select the context snippet that matches the parsed intent; unknown
    # intents fall through with "N/A".
    context_text = "N/A"
    if intent == "GET_WARNING":
        context_text = drug_data['official_label'].get('boxed_warning', 'No boxed warning.')
    elif intent == "CHECK_CONTRA":
        context_text = drug_data['official_label'].get('contraindications', 'No contraindications.')
    elif intent == "GET_SIDE_EFFECTS":
        # Top 10 reported side effects with their report counts.
        se = [f"{i['term']} ({i['count']})" for i in drug_data.get('side_effects', [])[:10]]
        context_text = ", ".join(se)
    elif intent == "CHECK_RELIABILITY":
        s = drug_data.get('reporter_stats', {})
        context_text = f"Clinical Reports: {s.get('clinical_reports')}\nConsumer Reports: {s.get('consumer_reports')}"
    elif intent == "GET_CLASS":
        context_text = drug_data.get('generic', 'Unknown')
    # Format Database Output for Display
    db_display = f"GENERIC_KEY: {generic_name}\nINTENT_MATCH: {intent}\n\n[EXTRACTED_DATA]:\n{context_text}"
    # Yield DB Result immediately
    yield parsed_obj, db_display, "Generating Summary..."
    # --- STEP 3: GENERATION (Streaming) ---
    # The model only summarizes the retrieved context ("Source of Truth").
    input_str = f"Databank found for {brand} (Generic: {generic_name}): {context_text}"
    assistant_messages = [
        {"role": "user", "content": f"SYSTEM: You are a medical assistant. Answer based on context.\nUSER: {user_query}\nINPUT: {input_str}"}
    ]
    stream = llm.create_chat_completion(
        messages=assistant_messages,
        max_tokens=250,
        temperature=0.01,
        stream=True
    )
    partial_ans = ""
    for chunk in stream:
        delta = chunk['choices'][0]['delta']
        if 'content' in delta:
            partial_ans += delta['content']
            # Yield ALL 3 outputs every time to keep UI consistent
            yield parsed_obj, db_display, partial_ans
# --- 4. UI LAYOUT ---
# Three-panel layout mirroring the 3-stage pipeline so each generator yield
# from pipeline_generator updates the corresponding panel live.
with gr.Blocks(theme=gr.themes.Soft(), title="PharmaVault Research Demo") as demo:
    gr.Markdown("# π‘οΈ PharmaVault AI: Verified Drug Safety Reasoning Engine")
    # --- RESEARCH INFO ACCORDION ---
    # Collapsed-by-default technical report on methodology and DB coverage.
    with gr.Accordion("π Technical Report & Database Coverage", open=False):
        gr.Markdown(
            """
# π‘οΈ PharmaVault AI: Verified Drug Safety Reasoning Engine
**Author:** st192011 | **Framework:** Llama-3.1-8B + Unsloth + GGUF
### 1. Abstract: Solving Parametric Hallucinations
Medical LLMs often suffer from **"parametric hallucinations,"** confidently inventing drug facts based on outdated weights. PharmaVault AI addresses this by **decoupling Reasoning from Knowledge**.
The model does not "know" the facts; it acts strictly as a **Reasoning Agent** that parses user intent, queries a verified 4-source database (RxNorm, openFDA, SIDER, EMA), and summarizes the retrieved "Source of Truth."
### 2. Methodology & Training
To achieve this multi-task capability, the model was fine-tuned on a balanced dataset of **10,000 samples**:
* **Task A (Semantic Parser - 5k):** Instruction-tuning to convert natural language into JSON keys (e.g., `{"brand": "WARFARIN"}`).
* **Task B (Verified Responder - 5k):** Training the model to synthesize context while ignoring internal pre-existing knowledge.
* **Specs:** Trained via Unsloth (LoRA, Rank 16) and exported to **GGUF (Q4_K_M)** for efficient CPU inference.
### 3. The 3-Stage Gated Pipeline
The application implements a strict logic flow to ensure safety:
1. π§ **Stage 1 (The Brain):** The Parser extracts the `Brand` and `Intent`.
2. π **Stage 2 (The Guardrail):** A Python-based lookup verifies the Brand against the Vault. **If not found, the process terminates.**
3. π©ββοΈ **Stage 3 (The Voice):** If verified, the model summarizes the data.
### 4. Logic Branching
| Situation | Trigger | Outcome |
| :--- | :--- | :--- |
| **Success** | Brand found in vault | AI Summary + Raw JSON displayed. |
| **Partial Fail** | Brand parsed but not in DB | Stops. Returns "Substance not covered". |
| **Critical Fail** | Brand not identified | Stops. Returns "Please rephrase". |
### 5. Current Database Coverage
The demonstration database contains **10 generic substances**, but via the synonym mapping layer, it covers over **2,500 global brands**:
* **Metformin:** 869 brands
* **Atorvastatin:** 563 brands
* **Rivaroxaban:** 336 brands
* **Lisinopril:** 253 brands
* **Insulin:** 187 brands
* **Methotrexate:** 126 brands
* **Amiodarone:** 86 brands
* **Phenytoin:** 54 brands
* **Digoxin:** 52 brands
* **Warfarin:** 27 brands
"""
        )
    # --- MAIN INTERFACE ---
    # Query box + trigger button share one row.
    with gr.Row():
        input_box = gr.Textbox(
            label="Patient Query",
            placeholder="e.g., What is the black box warning for Warfarin?",
            lines=2,
            scale=4
        )
        btn = gr.Button("Analyze", variant="primary", scale=1)
    # ROW 1: INTERMEDIATE STEPS (parser JSON + database guardrail output)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### π§ 1. Semantic Parser Output")
            out_parser = gr.JSON(label="Extracted Intent")
        with gr.Column(scale=2):
            gr.Markdown("### π 2. Database Guardrail Output")
            out_db = gr.Code(label="Verified Content (Source of Truth)", language="yaml", interactive=False)
    # ROW 2: FINAL RESULT (streamed markdown summary)
    gr.Markdown("### π©ββοΈ 3. Final Medical Summary")
    out_final = gr.Markdown(value="Waiting for input...")
    # --- EVENT HANDLER ---
    # pipeline_generator is a generator fn: each yield updates all 3 outputs.
    btn.click(
        fn=pipeline_generator,
        inputs=[input_box],
        outputs=[out_parser, out_db, out_final]
    )
    # Example Cache
    gr.Examples(
        examples=[
            ["What is the black box warning for Warfarin?"],
            ["Is Zentridol an antibiotic? (Ghost Drug Test)"],
            ["Give me the side effects of Nafordyl."],
            ["Is Metotrexato Accord safe for children?"]
        ],
        inputs=input_box
    )
# Entry point: bind to all interfaces on port 7860 (the standard HF Spaces port).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)