import streamlit as st
import os
from src.bio_rag.pipeline import BioRAGPipeline
# --- Page Configuration ---
# NOTE(review): the page_icon glyph "π₯" looks mojibake'd — presumably a
# hospital emoji originally; confirm against the repo's canonical encoding.
st.set_page_config(page_title="BioRAG Medical Assistant", page_icon="π₯", layout="wide")

# --- Load Custom CSS ---
# Resolve assets/style.css relative to this file so the app works no matter
# which working directory Streamlit is launched from.
css_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "style.css")
if os.path.exists(css_path):
    # BUGFIX: read with an explicit encoding instead of the platform default,
    # so non-ASCII characters in the stylesheet don't break (e.g. on Windows).
    with open(css_path, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# --- Cached Pipeline Initialization ---
@st.cache_resource(show_spinner=False)
def load_pipeline():
    """Build and cache the BioRAG pipeline.

    Cached as a resource so vector stores and models are loaded once per
    server process, not on every rerun of the script.
    """
    return BioRAGPipeline()


# Construct (or fetch the cached) pipeline up front, without any UI spinner.
pipeline = load_pipeline()
# --- Sidebar ---
# NOTE(review): emoji glyphs inside the HTML literals below appear
# mojibake'd (e.g. "π₯", "ποΈ") — presumably medical/status icons; confirm
# the original encoding before re-saving. Literals are left byte-identical.
with st.sidebar:
    # App branding header (logo, name, tagline).
    st.markdown("""
<div style="text-align:center; padding: 1rem 0 0.5rem;">
<div style="font-size: 2.5rem;">π₯</div>
<div style="font-size: 1.3rem; font-weight: 700; color: #1e293b; margin-top: 0.3rem;">BioRAG</div>
<div style="font-size: 0.8rem; color: #64748b;">Medical Hallucination Detector</div>
</div>
""", unsafe_allow_html=True)
    st.markdown("---")
    # Static overview of the two-phase verification pipeline.
    st.markdown("""
<div style="padding: 0.6rem 0;">
<div style="font-size: 0.75rem; color: #64748b; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 0.5rem;">Two-Phase Pipeline</div>
<div style="color: #334155; font-size: 0.85rem; line-height: 2;">
<span style="color: #2563eb;">β </span> <b>Phase 1:</b> Retrieval & Generation<br>
<span style="color: #0d9488;">β‘</span> <b>Phase 2:</b> Decompose into Claims<br>
<span style="color: #d97706;">β’</span> <b>Phase 2:</b> NLI Verification<br>
<span style="color: #dc2626;">β£</span> <b>Phase 2:</b> Clinical Risk Scoring
</div>
</div>
""", unsafe_allow_html=True)
    st.markdown("---")
    # Models / components used by the pipeline.
    st.markdown("""
<div style="padding: 0.4rem 0;">
<div style="font-size: 0.75rem; color: #64748b; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 0.5rem;">Tech Stack</div>
<div style="color: #475569; font-size: 0.78rem; line-height: 1.9;">
βοΈ <span style="color: #7c3aed;">llama-3.1-8b-instant (Groq)</span><br>
π‘οΈ <span style="color: #059669;">nli-deberta-v3-base</span><br>
π’ <span style="color: #2563eb;">FAISS Hybrid Retrieval</span>
</div>
</div>
""", unsafe_allow_html=True)
    st.markdown("---")
    # Reset the conversation: wipe stored messages and rerun the script.
    if st.button("ποΈ Clear Chat History"):
        st.session_state.messages = []
        st.rerun()
# --- Main App Header ---
_HEADER_HTML = """
<div style="padding: 0.5rem 0 0.3rem;">
<h1 style="color: #1e293b; font-size: 1.6rem; margin-bottom: 0.2rem;">π₯ Bio-RAG: Clinical Fact-Checking</h1>
<p style="color: #64748b; font-size: 0.88rem; margin: 0;">Generates an answer and scores its risk of hallucination using NLI and Clinical Severity heuristics.</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.markdown("---")

# --- Chat State Management ---
# Seed the message list on first run; later reruns keep the stored history.
st.session_state.setdefault("messages", [])
# --- Render Chat History ---
# Replays the stored conversation on every Streamlit rerun. Assistant
# messages may carry "result_data" (the serialized pipeline result), which is
# used to re-render the risk badge and the verification-details expander.
for msg in st.session_state.messages:
    if msg["role"] == "user":
        with st.chat_message("user"):
            st.markdown(msg["content"])
    elif msg["role"] == "assistant":
        with st.chat_message("assistant"):
            st.markdown(msg["content"])
            # Display Risk Badge if it's an assistant message and successfully scored
            if "result_data" in msg:
                res = msg["result_data"]
                if res.get("rejection_message"):
                    # Rejection text is already part of msg["content"];
                    # no badge or expander is rendered for rejected answers.
                    pass
                else:
                    max_risk = res.get("max_risk_score", 0.0)
                    is_safe = res.get("safe", False)
                    # BUGFIX: the "safe" f-string was split across two
                    # physical lines in the original (a SyntaxError);
                    # rejoined onto one line to match the flagged branch.
                    # NOTE(review): leading glyphs look mojibake'd —
                    # presumably check/warning emoji; confirm encoding.
                    if is_safe:
                        st.markdown(f"β **Safe (Low Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**")
                    else:
                        st.markdown(f"β οΈ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**. Answer has been redacted.")
                    # Add an expander for the detailed claim breakdown
                    with st.expander("π View Verification Details"):
                        st.markdown("### Atomic Claims & Risk Scores")
                        for claim_check in res.get("claim_checks", []):
                            risk_val = claim_check.get("risk_score", 0.0)
                            st.markdown(f"""
**Claim:** {claim_check.get('claim')}
- **NLI Contradiction Prob:** {claim_check.get('nli_prob')}
- **Risk Score: {risk_val:.4f}**
---
""")
                        if res.get("evidence"):
                            st.markdown("### Retrieved Context (Top Passages)")
                            for idx, ev in enumerate(res.get("evidence", [])[:3]):
                                # Evidence entries may be plain dicts or objects exposing .text.
                                text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                                st.info(f"**Document {idx+1}:** {text}")
# --- Handle User Input ---
# Runs once per user submission: echoes the question, invokes the RAG
# pipeline, renders the answer + risk verdict, then persists everything to
# session state so the history renderer above can replay it.
if prompt := st.chat_input("Ask a medical question about diabetes (e.g., 'Is high insulin dose safe for mild sugar elevation?')..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("π€ Phase 1: Retrieving context & Generating answer via Groq..."):
            # NOTE(review): this spinner body is a no-op — both phases run
            # under the Phase 2 spinner below, so this one only flashes.
            # Consider per-phase progress callbacks from the pipeline.
            pass
        with st.spinner("π‘οΈ Phase 2: Evaluating Claims & Calculating Clinical Risk (DeBERTa NLI)..."):
            # Call the Pipeline (retrieval, generation and verification).
            result = pipeline.ask(prompt)

        answer_text = result.final_answer
        st.markdown(answer_text)

        if not result.rejection_message:
            # BUGFIX: the "safe" f-string was split across two physical lines
            # in the original (a SyntaxError); rejoined onto one line.
            # NOTE(review): leading glyphs look mojibake'd — presumably
            # check/warning emoji; confirm encoding.
            if result.safe:
                st.success(f"β **Safe (Low Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**")
            else:
                st.error(f"β οΈ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**. Answer has been redacted.")
            with st.expander("π View Verification Details"):
                st.markdown("### Atomic Claims & Risk Scores")
                for claim_check in result.claim_checks:
                    risk_val = claim_check.get('risk_score', 0.0)
                    st.markdown(f"""
**Claim:** {claim_check.get('claim')}
- **NLI Contradiction Prob:** {claim_check.get('nli_prob')}
- **Risk Score: {risk_val:.4f}**
---
""")
                if result.evidence:
                    st.markdown("### Retrieved Context (Top Passages)")
                    for idx, ev in enumerate(result.evidence[:3]):
                        # Evidence entries may be plain dicts or objects exposing .text.
                        text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                        st.info(f"**Document {idx+1}:** {text}")

    # Save assistant message to state with result data
    # We need to make sure result.evidence is properly serialized or ignored to avoid st.session_state issues.
    # result.to_dict() is safe as long as it handles RetrievedPassage correctly.
    st.session_state.messages.append({
        "role": "assistant",
        "content": answer_text,
        "result_data": result.to_dict()
    })