# Pharma_Vault / app.py — PharmaVault AI demo (author: st192011)
import os
import json
import re
import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# --- 1. CONFIGURATION ---
REPO_ID = "st192011/llama-pharma-assistant"
GGUF_FILENAME = "pharma_model_q4.gguf"
DB_FILE = "ultimate_safety_vault.json"

# --- 2. LOAD RESOURCES ---
# Load the verified safety database (mapping: generic name -> safety record).
# A missing or corrupt file degrades to an empty vault so the UI still starts.
try:
    with open(DB_FILE, 'r', encoding='utf-8') as f:
        db = json.load(f)
    print(f"βœ… Database loaded: {len(db)} generic entries.")
except (FileNotFoundError, json.JSONDecodeError) as e:
    # Narrow catch: only the two expected failure modes (absent file, bad JSON).
    print(f"❌ DB not found. ({e})")
    db = {}

# Download and load the GGUF model. HF_TOKEN is optional (None for public repos).
print("⏳ Loading Model...")
try:
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=GGUF_FILENAME,
        token=os.environ.get("HF_TOKEN")
    )
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,    # context window size
        n_threads=2,   # CPU-only Space: keep thread count modest
        verbose=False
    )
    print("βœ… Model Ready.")
except Exception as e:
    # Deliberately broad at this boundary: the app must still come up so the
    # UI can report "Model not loaded" instead of crashing at import time.
    print(f"❌ Model Error: {e}")
    llm = None
# --- 3. CORE LOGIC ---
def find_drug_in_db(brand_query, database=None):
    """Resolve a user-supplied drug name against the safety vault.

    Args:
        brand_query: Raw drug name (any case; may be None or empty).
        database: Optional mapping of generic name -> record. Defaults to
            the module-level ``db`` vault, so existing callers are unchanged.

    Returns:
        Tuple of (generic_name, record) on a match, otherwise (None, None).
    """
    if not brand_query:
        return None, None
    if database is None:
        database = db
    query = brand_query.strip().upper()
    # 1. Direct hit on a generic key (vault keys are stored uppercase).
    if query in database:
        return query, database[query]
    # 2. Fall back to the brand/synonym lists attached to each generic entry.
    for generic, data in database.items():
        if any(query == brand.upper() for brand in data.get("eu_brands", [])):
            return generic, data
    return None, None
def pipeline_generator(user_query):
    """Run the 3-stage gated pipeline, yielding UI updates at each stage.

    Yields tuples of (parser_json, db_code_text, final_markdown) so all three
    Gradio output panels stay consistent while the pipeline progresses.
    """
    if not llm:
        yield {}, "Error: Model not loaded.", "System Error."
        return

    # --- STEP 1: PARSING ---
    # Yield initial state so the UI shows progress immediately.
    yield {"status": "Parsing..."}, "Waiting for Parser...", "Waiting..."
    # NOTE: fixed prompt typo ("JOSN" -> "JSON").
    parser_messages = [
        {"role": "user", "content": f"SYSTEM: You are a semantic parser. Extract intent and brand into JSON and JSON only.\nUSER: {user_query}"}
    ]
    # Non-streaming call: the full JSON object is needed before continuing.
    parser_out = llm.create_chat_completion(messages=parser_messages, max_tokens=60, temperature=0.01)
    raw_json = parser_out['choices'][0]['message']['content']

    parsed_obj = {}
    try:
        # The model may wrap the JSON in prose; extract the outermost {...}.
        match = re.search(r'\{.*\}', raw_json, re.DOTALL)
        if match:
            parsed_obj = json.loads(match.group(0))
        else:
            yield {"raw": raw_json, "error": "No JSON found"}, "Pipeline Stopped.", "⚠️ I couldn't identify the drug name."
            return
    except json.JSONDecodeError:
        # Narrowed from a bare except: only malformed JSON is expected here.
        yield {"raw": raw_json, "error": "JSON Decode Error"}, "Pipeline Stopped.", "⚠️ Internal Parsing Error."
        return

    # --- STEP 2: DATABASE GUARDRAIL ---
    brand = parsed_obj.get("brand", "UNKNOWN")
    intent = parsed_obj.get("intent", "UNKNOWN")
    # Yield the parser result immediately so the first panel updates.
    yield parsed_obj, f"Searching Vault for '{brand}'...", "..."

    generic_name, drug_data = find_drug_in_db(brand)
    if not drug_data:
        # Hard stop: nothing is generated for substances outside the vault.
        not_found_msg = f"ERROR: '{brand}' not found in Verified Vault."
        final_msg = f"πŸ›‘ **Not Found:** I cannot provide information on **{brand}** because it is not in the verified safety database (Demo Limit: 10 Substances)."
        yield parsed_obj, not_found_msg, final_msg
        return

    # Pull the intent-specific slice of the record (missing sections degrade
    # to safe defaults instead of raising KeyError).
    context_text = _extract_context(intent, drug_data)

    # Format database output for display in the guardrail panel.
    db_display = f"GENERIC_KEY: {generic_name}\nINTENT_MATCH: {intent}\n\n[EXTRACTED_DATA]:\n{context_text}"
    yield parsed_obj, db_display, "Generating Summary..."

    # --- STEP 3: GENERATION (Streaming) ---
    input_str = f"Databank found for {brand} (Generic: {generic_name}): {context_text}"
    assistant_messages = [
        {"role": "user", "content": f"SYSTEM: You are a medical assistant. Answer based on context.\nUSER: {user_query}\nINPUT: {input_str}"}
    ]
    stream = llm.create_chat_completion(
        messages=assistant_messages,
        max_tokens=250,
        temperature=0.01,
        stream=True
    )
    partial_ans = ""
    for chunk in stream:
        delta = chunk['choices'][0]['delta']
        if 'content' in delta:
            partial_ans += delta['content']
            # Re-yield all 3 outputs every time to keep the UI consistent.
            yield parsed_obj, db_display, partial_ans


def _extract_context(intent, drug_data):
    """Map a parsed intent to the matching section of the drug record."""
    # Guard: 'official_label' may be absent from a record; default to {}.
    label = drug_data.get('official_label', {})
    if intent == "GET_WARNING":
        return label.get('boxed_warning', 'No boxed warning.')
    if intent == "CHECK_CONTRA":
        return label.get('contraindications', 'No contraindications.')
    if intent == "GET_SIDE_EFFECTS":
        # Top-10 reported side effects as "term (count)" pairs.
        se = [f"{i['term']} ({i['count']})" for i in drug_data.get('side_effects', [])[:10]]
        return ", ".join(se)
    if intent == "CHECK_RELIABILITY":
        s = drug_data.get('reporter_stats', {})
        return f"Clinical Reports: {s.get('clinical_reports')}\nConsumer Reports: {s.get('consumer_reports')}"
    if intent == "GET_CLASS":
        return drug_data.get('generic', 'Unknown')
    return "N/A"
# --- 4. UI LAYOUT ---
# --- 4. UI LAYOUT ---
# Single-page Gradio app: one query box plus three output panels, one per
# pipeline stage (parser JSON, database guardrail text, streamed summary).
with gr.Blocks(theme=gr.themes.Soft(), title="PharmaVault Research Demo") as demo:
    gr.Markdown("# πŸ›‘οΈ PharmaVault AI: Verified Drug Safety Reasoning Engine")
    # --- RESEARCH INFO ACCORDION ---
    # Collapsed by default so the report does not push the demo below the fold.
    with gr.Accordion("πŸ“š Technical Report & Database Coverage", open=False):
        gr.Markdown(
            """
# πŸ›‘οΈ PharmaVault AI: Verified Drug Safety Reasoning Engine
**Author:** st192011 | **Framework:** Llama-3.1-8B + Unsloth + GGUF
### 1. Abstract: Solving Parametric Hallucinations
Medical LLMs often suffer from **"parametric hallucinations,"** confidently inventing drug facts based on outdated weights. PharmaVault AI addresses this by **decoupling Reasoning from Knowledge**.
The model does not "know" the facts; it acts strictly as a **Reasoning Agent** that parses user intent, queries a verified 4-source database (RxNorm, openFDA, SIDER, EMA), and summarizes the retrieved "Source of Truth."
### 2. Methodology & Training
To achieve this multi-task capability, the model was fine-tuned on a balanced dataset of **10,000 samples**:
* **Task A (Semantic Parser - 5k):** Instruction-tuning to convert natural language into JSON keys (e.g., `{"brand": "WARFARIN"}`).
* **Task B (Verified Responder - 5k):** Training the model to synthesize context while ignoring internal pre-existing knowledge.
* **Specs:** Trained via Unsloth (LoRA, Rank 16) and exported to **GGUF (Q4_K_M)** for efficient CPU inference.
### 3. The 3-Stage Gated Pipeline
The application implements a strict logic flow to ensure safety:
1. 🧠 **Stage 1 (The Brain):** The Parser extracts the `Brand` and `Intent`.
2. πŸ”’ **Stage 2 (The Guardrail):** A Python-based lookup verifies the Brand against the Vault. **If not found, the process terminates.**
3. πŸ‘©β€βš•οΈ **Stage 3 (The Voice):** If verified, the model summarizes the data.
### 4. Logic Branching
| Situation | Trigger | Outcome |
| :--- | :--- | :--- |
| **Success** | Brand found in vault | AI Summary + Raw JSON displayed. |
| **Partial Fail** | Brand parsed but not in DB | Stops. Returns "Substance not covered". |
| **Critical Fail** | Brand not identified | Stops. Returns "Please rephrase". |
### 5. Current Database Coverage
The demonstration database contains **10 generic substances**, but via the synonym mapping layer, it covers over **2,500 global brands**:
* **Metformin:** 869 brands
* **Atorvastatin:** 563 brands
* **Rivaroxaban:** 336 brands
* **Lisinopril:** 253 brands
* **Insulin:** 187 brands
* **Methotrexate:** 126 brands
* **Amiodarone:** 86 brands
* **Phenytoin:** 54 brands
* **Digoxin:** 52 brands
* **Warfarin:** 27 brands
"""
        )
    # --- MAIN INTERFACE ---
    with gr.Row():
        input_box = gr.Textbox(
            label="Patient Query",
            placeholder="e.g., What is the black box warning for Warfarin?",
            lines=2,
            scale=4
        )
        btn = gr.Button("Analyze", variant="primary", scale=1)
    # ROW 1: INTERMEDIATE STEPS (pipeline stages 1 and 2)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🧠 1. Semantic Parser Output")
            out_parser = gr.JSON(label="Extracted Intent")
        with gr.Column(scale=2):
            gr.Markdown("### πŸ”’ 2. Database Guardrail Output")
            # language="yaml" is a display choice for the plain-text vault dump.
            out_db = gr.Code(label="Verified Content (Source of Truth)", language="yaml", interactive=False)
    # ROW 2: FINAL RESULT (pipeline stage 3, streamed token by token)
    gr.Markdown("### πŸ‘©β€βš•οΈ 3. Final Medical Summary")
    out_final = gr.Markdown(value="Waiting for input...")
    # --- EVENT HANDLER ---
    # pipeline_generator is a generator: each yield refreshes all 3 outputs.
    btn.click(
        fn=pipeline_generator,
        inputs=[input_box],
        outputs=[out_parser, out_db, out_final]
    )
    # Example Cache
    gr.Examples(
        examples=[
            ["What is the black box warning for Warfarin?"],
            ["Is Zentridol an antibiotic? (Ghost Drug Test)"],
            ["Give me the side effects of Nafordyl."],
            ["Is Metotrexato Accord safe for children?"]
        ],
        inputs=input_box
    )
if __name__ == "__main__":
    # Bind to all interfaces on 7860, the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)