# BioRAG / app.py
# Deployed by aseelflihan ("Deploy Bio-RAG", commit 2a2c039) — Streamlit entry point.
import streamlit as st
import os
from src.bio_rag.pipeline import BioRAGPipeline
# --- Page Configuration ---
st.set_page_config(page_title="BioRAG Medical Assistant", page_icon="πŸ₯", layout="wide")

# --- Load Custom CSS ---
# Resolve the stylesheet relative to this file so the app works regardless of
# the current working directory (e.g. when launched by a Spaces runner).
css_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "style.css")
if os.path.exists(css_path):
    # Missing CSS is non-fatal: the app simply renders unstyled.
    # Explicit UTF-8 avoids mojibake on platforms with a non-UTF-8 default encoding.
    with open(css_path, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# --- Cached Pipeline Initialization ---
@st.cache_resource(show_spinner=False)
def load_pipeline():
    """Build and return the full Bio-RAG pipeline.

    Construction is expensive (it loads vector stores and models), so the
    instance is cached by ``st.cache_resource`` and shared across reruns
    and user sessions of this server process.

    Returns:
        BioRAGPipeline: the ready-to-query pipeline singleton.
    """
    return BioRAGPipeline()


# Initialize the pipeline silently behind the scenes (spinner suppressed above).
pipeline = load_pipeline()
# --- Sidebar ---
# NOTE: HTML inside the triple-quoted strings is kept flush-left on purpose —
# indenting it would make Markdown treat the lines as a code block.
with st.sidebar:
    # Branding header.
    st.markdown("""
<div style="text-align:center; padding: 1rem 0 0.5rem;">
<div style="font-size: 2.5rem;">πŸ₯</div>
<div style="font-size: 1.3rem; font-weight: 700; color: #1e293b; margin-top: 0.3rem;">BioRAG</div>
<div style="font-size: 0.8rem; color: #64748b;">Medical Hallucination Detector</div>
</div>
""", unsafe_allow_html=True)
    st.markdown("---")
    # Pipeline phase overview.
    st.markdown("""
<div style="padding: 0.6rem 0;">
<div style="font-size: 0.75rem; color: #64748b; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 0.5rem;">Two-Phase Pipeline</div>
<div style="color: #334155; font-size: 0.85rem; line-height: 2;">
<span style="color: #2563eb;">β‘ </span> <b>Phase 1:</b> Retrieval & Generation<br>
<span style="color: #0d9488;">β‘‘</span> <b>Phase 2:</b> Decompose into Claims<br>
<span style="color: #d97706;">β‘’</span> <b>Phase 2:</b> NLI Verification<br>
<span style="color: #dc2626;">β‘£</span> <b>Phase 2:</b> Clinical Risk Scoring
</div>
</div>
""", unsafe_allow_html=True)
    st.markdown("---")
    # Tech-stack summary.
    st.markdown("""
<div style="padding: 0.4rem 0;">
<div style="font-size: 0.75rem; color: #64748b; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 0.5rem;">Tech Stack</div>
<div style="color: #475569; font-size: 0.78rem; line-height: 1.9;">
☁️ <span style="color: #7c3aed;">llama-3.1-8b-instant (Groq)</span><br>
πŸ›‘οΈ <span style="color: #059669;">nli-deberta-v3-base</span><br>
πŸ”’ <span style="color: #2563eb;">FAISS Hybrid Retrieval</span>
</div>
</div>
""", unsafe_allow_html=True)
    st.markdown("---")
    # Reset the conversation and immediately rerender an empty chat.
    if st.button("πŸ—‘οΈ Clear Chat History"):
        st.session_state.messages = []
        st.rerun()
# --- Main App Header ---
# HTML kept flush-left inside the literal so Markdown doesn't render it as code.
st.markdown("""
<div style="padding: 0.5rem 0 0.3rem;">
<h1 style="color: #1e293b; font-size: 1.6rem; margin-bottom: 0.2rem;">πŸ₯ Bio-RAG: Clinical Fact-Checking</h1>
<p style="color: #64748b; font-size: 0.88rem; margin: 0;">Generates an answer and scores its risk of hallucination using NLI and Clinical Severity heuristics.</p>
</div>
""", unsafe_allow_html=True)
st.markdown("---")

# --- Chat State Management ---
# Conversation history persists across Streamlit reruns via session_state.
if "messages" not in st.session_state:
    st.session_state.messages = []
# --- Render Chat History ---
# Replays the stored conversation on every rerun. Assistant messages that carry
# "result_data" (a dict produced by result.to_dict()) also get a risk badge and
# a verification-details expander, mirroring the live-answer rendering below.
for msg in st.session_state.messages:
    if msg["role"] == "user":
        with st.chat_message("user"):
            st.markdown(msg["content"])
    elif msg["role"] == "assistant":
        with st.chat_message("assistant"):
            st.markdown(msg["content"])
            # Display Risk Badge if it's an assistant message and successfully scored
            if "result_data" in msg:
                res = msg["result_data"]
                if res.get("rejection_message"):
                    # Rejection text is already part of the stored answer; no badge shown.
                    pass
                else:
                    max_risk = res.get("max_risk_score", 0.0)
                    is_safe = res.get("safe", False)
                    if is_safe:
                        st.markdown(f"βœ… **Safe (Low Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**")
                    else:
                        st.markdown(f"⚠️ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**. Answer has been redacted.")
                    # Expander with the per-claim breakdown — shown only for scored
                    # (non-rejected) answers, matching the live-input path.
                    with st.expander("πŸ” View Verification Details"):
                        st.markdown("### Atomic Claims & Risk Scores")
                        for claim_check in res.get("claim_checks", []):
                            risk_val = claim_check.get("risk_score", 0.0)
                            # Markdown kept flush-left inside the f-string so the
                            # bullets and rule render correctly.
                            st.markdown(f"""
**Claim:** {claim_check.get('claim')}
- **NLI Contradiction Prob:** {claim_check.get('nli_prob')}
- **Risk Score: {risk_val:.4f}**
---
""")
                        if res.get("evidence"):
                            st.markdown("### Retrieved Context (Top Passages)")
                            for idx, ev in enumerate(res.get("evidence", [])[:3]):
                                # Evidence entries may be dicts or objects with a
                                # .text attribute depending on serialization.
                                text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                                st.info(f"**Document {idx+1}:** {text}")
# --- Handle User Input ---
if prompt := st.chat_input("Ask a medical question about diabetes (e.g., 'Is high insulin dose safe for mild sugar elevation?')..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        # One spinner spans the single pipeline call, which internally runs both
        # phases (retrieval/generation via Groq, then NLI claim verification).
        # A separate "Phase 1" spinner around an empty body was removed: it
        # displayed momentarily while doing no work.
        with st.spinner("πŸ›‘οΈ Running Bio-RAG: retrieving context, generating answer & scoring clinical risk (DeBERTa NLI)..."):
            result = pipeline.ask(prompt)
        answer_text = result.final_answer
        st.markdown(answer_text)
        if not result.rejection_message:
            # Risk badge for successfully scored answers.
            if result.safe:
                st.success(f"βœ… **Safe (Low Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**")
            else:
                st.error(f"⚠️ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**. Answer has been redacted.")
            with st.expander("πŸ” View Verification Details"):
                st.markdown("### Atomic Claims & Risk Scores")
                for claim_check in result.claim_checks:
                    risk_val = claim_check.get('risk_score', 0.0)
                    # Markdown kept flush-left inside the f-string for rendering.
                    st.markdown(f"""
**Claim:** {claim_check.get('claim')}
- **NLI Contradiction Prob:** {claim_check.get('nli_prob')}
- **Risk Score: {risk_val:.4f}**
---
""")
                if result.evidence:
                    st.markdown("### Retrieved Context (Top Passages)")
                    for idx, ev in enumerate(result.evidence[:3]):
                        # Evidence entries may be dicts or objects with .text.
                        text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                        st.info(f"**Document {idx+1}:** {text}")
    # Save assistant message to state with result data
    # result.to_dict() must serialize evidence (RetrievedPassage) into plain
    # types so session_state replay (see history renderer) stays safe.
    st.session_state.messages.append({
        "role": "assistant",
        "content": answer_text,
        "result_data": result.to_dict()
    })