"""Streamlit front end for the BioRAG medical hallucination detector.

Each user question is sent to ``BioRAGPipeline.ask``; the generated answer is
shown together with its clinical-risk verdict, the per-claim scores and the top
retrieved passages. Chat history lives in ``st.session_state.messages``;
assistant entries also store ``result.to_dict()`` under ``"result_data"`` so the
verification details can be re-rendered when Streamlit reruns the script.
"""

import logging
import os

import streamlit as st

from src.bio_rag.pipeline import BioRAGPipeline

logger = logging.getLogger(__name__)

# --- Page Configuration ---
# Kept as the first Streamlit call: Streamlit requires set_page_config to run
# before other st.* commands (at least in older releases).
st.set_page_config(page_title="BioRAG Medical Assistant", page_icon="🏥", layout="wide")

# --- Static content ---
_CSS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "style.css")

# Decorative text rendered via st.markdown(..., unsafe_allow_html=True).
_SIDEBAR_BRAND = """
🏥
BioRAG
Medical Hallucination Detector
"""

_SIDEBAR_PIPELINE = """
Two-Phase Pipeline
Phase 1: Retrieval & Generation
Phase 2: Decompose into Claims
Phase 2: NLI Verification
Phase 2: Clinical Risk Scoring
"""

_SIDEBAR_TECH_STACK = """
Tech Stack
☁️ llama-3.1-8b-instant (Groq)
🛡️ nli-deberta-v3-base
🔢 FAISS Hybrid Retrieval
"""

_PAGE_HEADER = """

🏥 Bio-RAG: Clinical Fact-Checking

Generates an answer and scores its risk of hallucination using NLI and Clinical Severity heuristics.

"""

_CHAT_PLACEHOLDER = (
    "Ask a medical question about diabetes "
    "(e.g., 'Is high insulin dose safe for mild sugar elevation?')..."
)

# pipeline.ask() runs retrieval, generation and verification in one call, so a
# single spinner covers the whole wait.
_SPINNER_TEXT = (
    "🤖 Retrieving context, generating an answer via Groq & "
    "evaluating claims (DeBERTa NLI)..."
)

_PIPELINE_ERROR_TEXT = (
    "⚠️ Sorry, the pipeline failed to answer this question. Please try again."
)

# Number of retrieved passages shown in the verification expander.
_MAX_EVIDENCE_SHOWN = 3


def _inject_css() -> None:
    """Inject ``assets/style.css`` (next to this script) into the page, if present."""
    if not os.path.exists(_CSS_PATH):
        return
    with open(_CSS_PATH, encoding="utf-8") as css_file:
        # The stylesheet only takes effect when wrapped in a <style> element.
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)


@st.cache_resource(show_spinner=False)
def load_pipeline() -> BioRAGPipeline:
    """Build the full RAG pipeline (this also loads vector stores and models).

    ``st.cache_resource`` returns the same instance on subsequent reruns, so the
    expensive construction happens only once.
    """
    return BioRAGPipeline()


def _render_sidebar() -> None:
    """Draw the sidebar: branding, pipeline summary, tech stack, clear button."""
    with st.sidebar:
        st.markdown(_SIDEBAR_BRAND, unsafe_allow_html=True)
        st.markdown("---")
        st.markdown(_SIDEBAR_PIPELINE, unsafe_allow_html=True)
        st.markdown("---")
        st.markdown(_SIDEBAR_TECH_STACK, unsafe_allow_html=True)
        st.markdown("---")
        if st.button("🗑️ Clear Chat History"):
            st.session_state.messages = []
            st.rerun()


def _render_header() -> None:
    """Draw the main page title and description."""
    st.markdown(_PAGE_HEADER, unsafe_allow_html=True)
    st.markdown("---")


def _format_score(value: object) -> str:
    """Format a numeric score with 4 decimals.

    Returns ``"N/A"`` for ``None`` and falls back to ``str(value)`` for values
    that cannot be converted to ``float``.
    """
    if value is None:
        return "N/A"
    try:
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        return str(value)


def _format_claim(claim_check: dict) -> str:
    """Build the markdown for one claim-check entry.

    The markdown is built without leading indentation, because indented lines
    would be rendered as a code block.
    """
    return (
        f"**Claim:** {claim_check.get('claim')}\n\n"
        f"- **NLI Contradiction Prob:** {_format_score(claim_check.get('nli_prob'))}\n"
        f"- **Risk Score: {_format_score(claim_check.get('risk_score', 0.0))}**\n\n"
        "---"
    )


def _evidence_text(passage: object) -> str:
    """Extract displayable text from a passage (dict, object with ``.text``, or other)."""
    if isinstance(passage, dict):
        return str(passage.get("text", passage))
    return str(getattr(passage, "text", passage))


def _render_verification(result_data: dict) -> None:
    """Render the risk verdict and verification details for one pipeline result.

    Args:
        result_data: The dict produced by ``result.to_dict()``. Keys read here:
            ``rejection_message``, ``max_risk_score``, ``safe``,
            ``claim_checks`` and ``evidence``.

    Nothing is rendered when ``rejection_message`` is set. The answer text is
    expected to already carry the rejection.
    """
    if result_data.get("rejection_message"):
        return

    max_risk = _format_score(result_data.get("max_risk_score", 0.0))
    if result_data.get("safe", False):
        st.success(f"✅ **Safe (Low Risk)**: Maximum Clinical Risk Score is **{max_risk}**")
    else:
        st.error(
            f"⚠️ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{max_risk}**. "
            "Answer has been redacted."
        )

    with st.expander("🔍 View Verification Details"):
        st.markdown("### Atomic Claims & Risk Scores")
        for claim_check in result_data.get("claim_checks") or []:
            st.markdown(_format_claim(claim_check))

        evidence = result_data.get("evidence") or []
        if evidence:
            st.markdown("### Retrieved Context (Top Passages)")
            for idx, passage in enumerate(evidence[:_MAX_EVIDENCE_SHOWN], start=1):
                st.info(f"**Document {idx}:** {_evidence_text(passage)}")


def _render_history() -> None:
    """Replay every stored chat message, including verification details."""
    for msg in st.session_state.messages:
        role = msg.get("role")
        if role not in ("user", "assistant"):
            continue
        with st.chat_message(role):
            st.markdown(msg["content"])
            if role == "assistant" and "result_data" in msg:
                _render_verification(msg["result_data"])


def _handle_prompt(prompt: str, rag_pipeline: BioRAGPipeline) -> None:
    """Answer *prompt* with the pipeline, render it, and append both turns to history."""
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        try:
            with st.spinner(_SPINNER_TEXT):
                result = rag_pipeline.ask(prompt)
        except Exception:
            # UI boundary: keep the chat usable instead of replacing the page
            # with a traceback. The prompt is not logged (it may hold patient data).
            logger.exception("BioRAG pipeline failed to answer a prompt")
            st.error(_PIPELINE_ERROR_TEXT)
            st.session_state.messages.append(
                {"role": "assistant", "content": _PIPELINE_ERROR_TEXT}
            )
            return

        answer_text = result.final_answer
        st.markdown(answer_text)
        # Render from the same serialised dict that goes into history, so the
        # live view and later reruns look identical. The comment in the original
        # notes that to_dict() must serialise evidence passages safely for
        # session_state.
        result_data = result.to_dict()
        _render_verification(result_data)

    st.session_state.messages.append(
        {"role": "assistant", "content": answer_text, "result_data": result_data}
    )


# --- Script flow (Streamlit reruns this top to bottom on every interaction) ---
_inject_css()

# Initialize the pipeline silently behind the scenes.
pipeline = load_pipeline()

if "messages" not in st.session_state:
    st.session_state.messages = []

_render_sidebar()
_render_header()
_render_history()

if prompt := st.chat_input(_CHAT_PLACEHOLDER):
    _handle_prompt(prompt, pipeline)