import streamlit as st
import os
from src.bio_rag.pipeline import BioRAGPipeline
# --- Page Configuration ---
st.set_page_config(page_title="BioRAG Medical Assistant", page_icon="🏥", layout="wide")

# --- Load Custom CSS ---
# Resolve the stylesheet relative to this file so the app works no matter
# which directory `streamlit run` is launched from.
css_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "style.css")
if os.path.exists(css_path):
    with open(css_path, encoding="utf-8") as f:
        # BUG FIX: the original injected an empty f-string, so the CSS file was
        # opened but its contents never reached the page. Wrap the stylesheet
        # in a <style> tag so Streamlit applies it via unsafe_allow_html.
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# --- Cached Pipeline Initialization ---
@st.cache_resource(show_spinner=False)
def load_pipeline():
    """Load the full RAG pipeline (this will also load vector stores and models)"""
    # st.cache_resource builds the pipeline once per server process and shares
    # the same instance across sessions and script reruns, so the (expensive)
    # model/vector-store loading happens only on the first request.
    return BioRAGPipeline()

# Initialize the pipeline silently behind the scenes
# (show_spinner=False above suppresses Streamlit's default cache spinner).
pipeline = load_pipeline()
# --- Sidebar ---
with st.sidebar:
    # App branding block (rendered as HTML/markdown).
    st.markdown("""
🏥
BioRAG
Medical Hallucination Detector
""", unsafe_allow_html=True)
    st.markdown("---")
    # High-level description of the two-phase pipeline for the user.
    st.markdown("""
Two-Phase Pipeline
① Phase 1: Retrieval & Generation
② Phase 2: Decompose into Claims
③ Phase 2: NLI Verification
④ Phase 2: Clinical Risk Scoring
""", unsafe_allow_html=True)
    st.markdown("---")
    # Models/components used under the hood.
    st.markdown("""
Tech Stack
☁️ llama-3.1-8b-instant (Groq)
🛡️ nli-deberta-v3-base
🔢 FAISS Hybrid Retrieval
""", unsafe_allow_html=True)
    st.markdown("---")
    # Wipe the conversation from session state and force an immediate rerun
    # so the cleared history is reflected right away.
    if st.button("🗑️ Clear Chat History"):
        st.session_state.messages = []
        st.rerun()
# --- Main App Header ---
st.markdown("""
🏥 Bio-RAG: Clinical Fact-Checking
Generates an answer and scores its risk of hallucination using NLI and Clinical Severity heuristics.
""", unsafe_allow_html=True)
st.markdown("---")

# --- Chat State Management ---
# st.session_state survives script reruns; initialize the message history
# only on the very first load of a session.
if "messages" not in st.session_state:
    st.session_state.messages = []
# --- Render Chat History ---
# Streamlit re-executes the whole script on every interaction, so the stored
# history must be replayed on each rerun.
for msg in st.session_state.messages:
    if msg["role"] == "user":
        with st.chat_message("user"):
            st.markdown(msg["content"])
    elif msg["role"] == "assistant":
        with st.chat_message("assistant"):
            st.markdown(msg["content"])
            # Display Risk Badge if it's an assistant message and successfully scored
            if "result_data" in msg:
                res = msg["result_data"]
                if res.get("rejection_message"):
                    pass  # Handled in the markdown output already implicitly, but can add badge:
                else:
                    max_risk = res.get("max_risk_score", 0.0)
                    is_safe = res.get("safe", False)
                    if is_safe:
                        st.markdown(f"✅ **Safe (Low Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**")
                    else:
                        st.markdown(f"⚠️ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**. Answer has been redacted.")
                    # Add an expander for the detailed claim breakdown
                    # (only for non-rejected answers, which carry claim checks).
                    with st.expander("🔍 View Verification Details"):
                        st.markdown("### Atomic Claims & Risk Scores")
                        for claim_check in res.get("claim_checks", []):
                            risk_val = claim_check.get("risk_score", 0.0)
                            st.markdown(f"""
**Claim:** {claim_check.get('claim')}
- **NLI Contradiction Prob:** {claim_check.get('nli_prob')}
- **Risk Score: {risk_val:.4f}**
---
""")
                        if res.get("evidence"):
                            st.markdown("### Retrieved Context (Top Passages)")
                            for idx, ev in enumerate(res.get("evidence", [])[:3]):
                                # Evidence entries may be plain dicts (from to_dict())
                                # or objects exposing .text — handle both defensively.
                                text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                                st.info(f"**Document {idx+1}:** {text}")
# --- Handle User Input ---
if prompt := st.chat_input("Ask a medical question about diabetes (e.g., 'Is high insulin dose safe for mild sugar elevation?')..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # BUG FIX: the original wrapped a bare `pass` in a first spinner
        # ("Phase 1..."), which opened and closed instantly and covered no
        # work — dead code. Both phases run inside pipeline.ask(), so a
        # single spinner wraps the whole call.
        with st.spinner("🛡️ Phase 2: Evaluating Claims & Calculating Clinical Risk (DeBERTa NLI)..."):
            # Call the Pipeline (retrieval + generation + claim verification).
            result = pipeline.ask(prompt)

        answer_text = result.final_answer
        st.markdown(answer_text)

        # Show the risk badge and details only when the query was not rejected.
        if not result.rejection_message:
            if result.safe:
                st.success(f"✅ **Safe (Low Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**")
            else:
                st.error(f"⚠️ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**. Answer has been redacted.")

            with st.expander("🔍 View Verification Details"):
                st.markdown("### Atomic Claims & Risk Scores")
                for claim_check in result.claim_checks:
                    risk_val = claim_check.get('risk_score', 0.0)
                    st.markdown(f"""
**Claim:** {claim_check.get('claim')}
- **NLI Contradiction Prob:** {claim_check.get('nli_prob')}
- **Risk Score: {risk_val:.4f}**
---
""")
                if result.evidence:
                    st.markdown("### Retrieved Context (Top Passages)")
                    for idx, ev in enumerate(result.evidence[:3]):
                        # Evidence entries may be dicts or objects exposing .text — handle both.
                        text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                        st.info(f"**Document {idx+1}:** {text}")

    # Save assistant message to state with result data.
    # result.to_dict() must serialize evidence (RetrievedPassage) into plain
    # dicts so it is safe to keep in st.session_state across reruns.
    st.session_state.messages.append({
        "role": "assistant",
        "content": answer_text,
        "result_data": result.to_dict()
    })