Spaces:
Running
Running
| import streamlit as st | |
| import os | |
| from src.bio_rag.pipeline import BioRAGPipeline | |
# --- Page Configuration ---
st.set_page_config(page_title="BioRAG Medical Assistant", page_icon="π₯", layout="wide")

# --- Load Custom CSS ---
# Resolve the stylesheet relative to this file so the app works from any CWD.
css_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "style.css")
if os.path.exists(css_path):
    # Explicit encoding: open() without one falls back to a platform-dependent
    # default (locale), which silently corrupts any non-ASCII glyphs in the CSS.
    with open(css_path, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# --- Cached Pipeline Initialization ---
@st.cache_resource
def load_pipeline():
    """Load the full RAG pipeline (this will also load vector stores and models).

    Decorated with st.cache_resource so the heavy resources (FAISS stores,
    NLI model) are built once per server process instead of on every
    Streamlit rerun — previously the "Cached" comment was aspirational:
    each user interaction re-instantiated BioRAGPipeline from scratch.
    """
    return BioRAGPipeline()

# Initialize the pipeline silently behind the scenes
pipeline = load_pipeline()
# --- Sidebar ---
# Static HTML panels, rendered top-to-bottom with a divider after each one.
_BRAND_PANEL = """
<div style="text-align:center; padding: 1rem 0 0.5rem;">
<div style="font-size: 2.5rem;">π₯</div>
<div style="font-size: 1.3rem; font-weight: 700; color: #1e293b; margin-top: 0.3rem;">BioRAG</div>
<div style="font-size: 0.8rem; color: #64748b;">Medical Hallucination Detector</div>
</div>
"""

_PIPELINE_PANEL = """
<div style="padding: 0.6rem 0;">
<div style="font-size: 0.75rem; color: #64748b; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 0.5rem;">Two-Phase Pipeline</div>
<div style="color: #334155; font-size: 0.85rem; line-height: 2;">
<span style="color: #2563eb;">β </span> <b>Phase 1:</b> Retrieval & Generation<br>
<span style="color: #0d9488;">β‘</span> <b>Phase 2:</b> Decompose into Claims<br>
<span style="color: #d97706;">β’</span> <b>Phase 2:</b> NLI Verification<br>
<span style="color: #dc2626;">β£</span> <b>Phase 2:</b> Clinical Risk Scoring
</div>
</div>
"""

_STACK_PANEL = """
<div style="padding: 0.4rem 0;">
<div style="font-size: 0.75rem; color: #64748b; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 0.5rem;">Tech Stack</div>
<div style="color: #475569; font-size: 0.78rem; line-height: 1.9;">
βοΈ <span style="color: #7c3aed;">llama-3.1-8b-instant (Groq)</span><br>
π‘οΈ <span style="color: #059669;">nli-deberta-v3-base</span><br>
π’ <span style="color: #2563eb;">FAISS Hybrid Retrieval</span>
</div>
</div>
"""

with st.sidebar:
    for panel_html in (_BRAND_PANEL, _PIPELINE_PANEL, _STACK_PANEL):
        st.markdown(panel_html, unsafe_allow_html=True)
        st.markdown("---")
    # Wipe the conversation and force an immediate re-render.
    if st.button("ποΈ Clear Chat History"):
        st.session_state.messages = []
        st.rerun()
# --- Main App Header ---
_HEADER_HTML = """
<div style="padding: 0.5rem 0 0.3rem;">
<h1 style="color: #1e293b; font-size: 1.6rem; margin-bottom: 0.2rem;">π₯ Bio-RAG: Clinical Fact-Checking</h1>
<p style="color: #64748b; font-size: 0.88rem; margin: 0;">Generates an answer and scores its risk of hallucination using NLI and Clinical Severity heuristics.</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.markdown("---")

# --- Chat State Management ---
# First run of a session: start with an empty conversation history.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
# --- Render Chat History ---
# Replay every stored turn. Assistant turns that carry "result_data" also get
# their risk badge and the verification-details expander re-rendered.
for msg in st.session_state.messages:
    role = msg["role"]
    if role not in ("user", "assistant"):
        continue
    with st.chat_message(role):
        st.markdown(msg["content"])
        if role != "assistant" or "result_data" not in msg:
            continue
        res = msg["result_data"]
        if res.get("rejection_message"):
            # Rejection text is already part of the stored answer markdown;
            # no badge is shown for rejected turns.
            continue
        max_risk = res.get("max_risk_score", 0.0)
        if res.get("safe", False):
            st.markdown(f"β **Safe (Low Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**")
        else:
            st.markdown(f"β οΈ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**. Answer has been redacted.")
        # Detailed claim-by-claim breakdown, collapsed by default.
        with st.expander("π View Verification Details"):
            st.markdown("### Atomic Claims & Risk Scores")
            for claim_check in res.get("claim_checks", []):
                risk_val = claim_check.get("risk_score", 0.0)
                st.markdown(f"""
**Claim:** {claim_check.get('claim')}
- **NLI Contradiction Prob:** {claim_check.get('nli_prob')}
- **Risk Score: {risk_val:.4f}**
---
""")
            if res.get("evidence"):
                st.markdown("### Retrieved Context (Top Passages)")
                for idx, ev in enumerate(res.get("evidence", [])[:3]):
                    # Evidence may be a dict, a RetrievedPassage-like object,
                    # or a plain string depending on serialization.
                    text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                    st.info(f"**Document {idx+1}:** {text}")
# --- Handle User Input ---
if prompt := st.chat_input("Ask a medical question about diabetes (e.g., 'Is high insulin dose safe for mild sugar elevation?')..."):
    # Persist and echo the user's turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # Single spinner for the whole run: the previous version showed a
        # "Phase 1" spinner around a literal `pass` (a misleading no-op flash)
        # before the real work. pipeline.ask() performs both phases internally,
        # so one spinner message covering both is the honest UI.
        with st.spinner("π‘οΈ Phase 1 & 2: Retrieving context, generating via Groq, and scoring clinical risk (DeBERTa NLI)..."):
            result = pipeline.ask(prompt)

        answer_text = result.final_answer
        st.markdown(answer_text)

        # Rejected queries carry their explanation in final_answer; only
        # scored answers get a risk badge and the details expander.
        if not result.rejection_message:
            if result.safe:
                st.success(f"β **Safe (Low Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**")
            else:
                st.error(f"β οΈ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**. Answer has been redacted.")

            with st.expander("π View Verification Details"):
                st.markdown("### Atomic Claims & Risk Scores")
                for claim_check in result.claim_checks:
                    risk_val = claim_check.get('risk_score', 0.0)
                    st.markdown(f"""
**Claim:** {claim_check.get('claim')}
- **NLI Contradiction Prob:** {claim_check.get('nli_prob')}
- **Risk Score: {risk_val:.4f}**
---
""")
                if result.evidence:
                    st.markdown("### Retrieved Context (Top Passages)")
                    for idx, ev in enumerate(result.evidence[:3]):
                        # Evidence items may be dicts, objects with .text, or strings.
                        text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                        st.info(f"**Document {idx+1}:** {text}")

    # Save assistant message to state with result data.
    # to_dict() keeps session_state serializable (avoids storing raw
    # RetrievedPassage objects — presumably handled there; verify in pipeline).
    st.session_state.messages.append({
        "role": "assistant",
        "content": answer_text,
        "result_data": result.to_dict(),
    })