File size: 8,156 Bytes
2a2c039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import streamlit as st
import os
from src.bio_rag.pipeline import BioRAGPipeline

# --- Page Configuration ---
st.set_page_config(page_title="BioRAG Medical Assistant", page_icon="πŸ₯", layout="wide")

# --- Load Custom CSS ---
# Resolve the stylesheet relative to this file so the app renders the same
# regardless of the working directory Streamlit is launched from.
css_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "style.css")
if os.path.exists(css_path):
    # Explicit encoding: the default codec is platform-dependent and can
    # corrupt non-ASCII characters in the stylesheet.
    with open(css_path, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# --- Cached Pipeline Initialization ---
@st.cache_resource(show_spinner=False)
def load_pipeline():
    """Build the full RAG pipeline (also loads vector stores and models).

    Decorated with st.cache_resource so the expensive construction happens
    once per server process instead of on every script rerun.
    """
    rag_pipeline = BioRAGPipeline()
    return rag_pipeline

# Initialize the pipeline silently behind the scenes
pipeline = load_pipeline()

# --- Sidebar ---
# Static branding / explanatory panels (raw HTML via unsafe_allow_html) plus
# a button that wipes the conversation state.
with st.sidebar:
    # App logo and title card.
    st.markdown("""
    <div style="text-align:center; padding: 1rem 0 0.5rem;">
        <div style="font-size: 2.5rem;">πŸ₯</div>
        <div style="font-size: 1.3rem; font-weight: 700; color: #1e293b; margin-top: 0.3rem;">BioRAG</div>
        <div style="font-size: 0.8rem; color: #64748b;">Medical Hallucination Detector</div>
    </div>
    """, unsafe_allow_html=True)

    st.markdown("---")

    # Overview of the two-phase verification pipeline shown to the user.
    st.markdown("""
    <div style="padding: 0.6rem 0;">
        <div style="font-size: 0.75rem; color: #64748b; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 0.5rem;">Two-Phase Pipeline</div>
        <div style="color: #334155; font-size: 0.85rem; line-height: 2;">
            <span style="color: #2563eb;">β‘ </span> <b>Phase 1:</b> Retrieval & Generation<br>
            <span style="color: #0d9488;">β‘‘</span> <b>Phase 2:</b> Decompose into Claims<br>
            <span style="color: #d97706;">β‘’</span> <b>Phase 2:</b> NLI Verification<br>
            <span style="color: #dc2626;">β‘£</span> <b>Phase 2:</b> Clinical Risk Scoring
        </div>
    </div>
    """, unsafe_allow_html=True)

    st.markdown("---")

    # Models and retrieval backends used by the pipeline (display only).
    st.markdown("""
    <div style="padding: 0.4rem 0;">
        <div style="font-size: 0.75rem; color: #64748b; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 0.5rem;">Tech Stack</div>
        <div style="color: #475569; font-size: 0.78rem; line-height: 1.9;">
            ☁️ <span style="color: #7c3aed;">llama-3.1-8b-instant (Groq)</span><br>
            πŸ›‘οΈ <span style="color: #059669;">nli-deberta-v3-base</span><br>
            πŸ”’ <span style="color: #2563eb;">FAISS Hybrid Retrieval</span>
        </div>
    </div>
    """, unsafe_allow_html=True)

    st.markdown("---")
    # Reset the conversation: clear the message store and rerun the script
    # so the empty history is rendered immediately.
    if st.button("πŸ—‘οΈ Clear Chat History"):
        st.session_state.messages = []
        st.rerun()

# --- Main App Header ---
_HEADER_HTML = """
<div style="padding: 0.5rem 0 0.3rem;">
    <h1 style="color: #1e293b; font-size: 1.6rem; margin-bottom: 0.2rem;">πŸ₯ Bio-RAG: Clinical Fact-Checking</h1>
    <p style="color: #64748b; font-size: 0.88rem; margin: 0;">Generates an answer and scores its risk of hallucination using NLI and Clinical Severity heuristics.</p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.markdown("---")

# --- Chat State Management ---
# Seed the conversation store on the very first run; it survives reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# --- Render Chat History ---
# Replays the stored conversation on every Streamlit rerun. Assistant turns
# may carry serialized verification results under "result_data", which are
# re-rendered as a risk badge plus a claim-by-claim breakdown.
for msg in st.session_state.messages:
    role = msg["role"]
    # Both roles render identically (chat bubble + markdown body), so the
    # previous duplicated if/elif branches are collapsed into one.
    if role in ("user", "assistant"):
        with st.chat_message(role):
            st.markdown(msg["content"])

            # Verification details exist only on scored assistant turns.
            if role == "assistant" and "result_data" in msg:
                res = msg["result_data"]

                # Rejected queries explain themselves in the message body;
                # no badge or breakdown is shown for them. (Replaces a dead
                # `if rejection_message: pass` branch.)
                if not res.get("rejection_message"):
                    max_risk = res.get("max_risk_score", 0.0)

                    if res.get("safe", False):
                        st.markdown(f"βœ… **Safe (Low Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**")
                    else:
                        st.markdown(f"⚠️ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{max_risk:.4f}**. Answer has been redacted.")

                    # Expandable claim-by-claim breakdown.
                    with st.expander("πŸ” View Verification Details"):
                        st.markdown("### Atomic Claims & Risk Scores")
                        for claim_check in res.get("claim_checks", []):
                            risk_val = claim_check.get("risk_score", 0.0)

                            st.markdown(f"""
                            **Claim:** {claim_check.get('claim')}  
                            - **NLI Contradiction Prob:** {claim_check.get('nli_prob')}  
                            - **Risk Score: {risk_val:.4f}**
                            ---
                            """)

                        # Evaluate the evidence list once instead of twice.
                        evidence = res.get("evidence") or []
                        if evidence:
                            st.markdown("### Retrieved Context (Top Passages)")
                            for idx, ev in enumerate(evidence[:3]):
                                # Evidence may be a plain dict or a passage-like object.
                                text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                                st.info(f"**Document {idx+1}:** {text}")

# --- Handle User Input ---
if prompt := st.chat_input("Ask a medical question about diabetes (e.g., 'Is high insulin dose safe for mild sugar elevation?')..."):
    # Persist and echo the user turn.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # FIX: a previous revision wrapped a bare `pass` in a "Phase 1"
        # spinner, which flashed a fake progress indicator around no work.
        # Both phases actually execute inside pipeline.ask(), so one spinner
        # covering the whole call is shown instead.
        with st.spinner("πŸ€– Retrieving context, generating answer & scoring clinical risk (Groq + DeBERTa NLI)..."):
            # Runs Phase 1 (retrieval + generation) and Phase 2 (claim
            # decomposition, NLI verification, clinical risk scoring).
            result = pipeline.ask(prompt)
            answer_text = result.final_answer

            st.markdown(answer_text)

            # Rejected queries carry their explanation in final_answer;
            # the badge and breakdown apply only to scored answers.
            if not result.rejection_message:
                if result.safe:
                    st.success(f"βœ… **Safe (Low Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**")
                else:
                    st.error(f"⚠️ **FLAGGED (High Risk)**: Maximum Clinical Risk Score is **{result.max_risk_score:.4f}**. Answer has been redacted.")

                # Expandable claim-by-claim breakdown.
                with st.expander("πŸ” View Verification Details"):
                    st.markdown("### Atomic Claims & Risk Scores")
                    for claim_check in result.claim_checks:
                        risk_val = claim_check.get('risk_score', 0.0)

                        st.markdown(f"""
                        **Claim:** {claim_check.get('claim')}  
                        - **NLI Contradiction Prob:** {claim_check.get('nli_prob')}  
                        - **Risk Score: {risk_val:.4f}**
                        ---
                        """)

                    if result.evidence:
                        st.markdown("### Retrieved Context (Top Passages)")
                        for idx, ev in enumerate(result.evidence[:3]):
                            # Evidence may be a plain dict or a passage-like object.
                            text = ev.get('text', str(ev)) if isinstance(ev, dict) else (ev.text if hasattr(ev, 'text') else str(ev))
                            st.info(f"**Document {idx+1}:** {text}")

        # Persist the assistant turn with its serialized verification data so
        # the history renderer can rebuild the badge/breakdown on rerun.
        # result.to_dict() keeps session_state serialization-safe (evidence
        # objects are not stored raw).
        st.session_state.messages.append({
            "role": "assistant", 
            "content": answer_text,
            "result_data": result.to_dict()
        })