""" app/streamlit_app.py ===================== ALZDETECT-AI — Streamlit Chatbot Interface. WHAT: Chat UI for Alzheimer's research questions. WHY: Makes RAG pipeline accessible to researchers. WHO: End users — researchers, clinicians, students. WHERE: Runs locally or on HuggingFace Spaces. WHEN: Every user session. WORST-CASE DESIGN: - RAG pipeline error → clear error message, not crash - Empty question → caught before pipeline call - No chunks retrieved → honest "no evidence" message - Claude timeout → retry handled in rag_pipeline.py """ import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) import streamlit as st from loguru import logger from generation.rag_pipeline import RAGPipeline, ConversationMemory from vector_store.pinecone_store import QueryInput # ── Page config ─────────────────────────────────────────────────── st.set_page_config( page_title = "AlzDetect AI", page_icon = "🧠", layout = "wide", initial_sidebar_state = "expanded", ) # ── Load pipeline once — cached ─────────────────────────────────── @st.cache_resource def load_pipeline() -> RAGPipeline: """ Load RAG pipeline once and cache it. Streamlit reruns on every interaction — caching prevents reloading the model every time. Analogy: The doctor arrives at the hospital once in the morning — not for every patient visit. """ logger.info("[APP] Loading RAG pipeline...") pipeline = RAGPipeline() logger.info("[APP] Pipeline ready") return pipeline # ── Sidebar ─────────────────────────────────────────────────────── with st.sidebar: st.image("https://img.icons8.com/color/96/brain.png", width=80) st.title("AlzDetect AI") st.markdown("*Alzheimer's Research Assistant*") st.divider() st.markdown("### Analysis mode") analysis_mode = st.radio( "Choose analysis depth:", options=["Quick", "Deep"], captions=[ "Single Claude call — ~5 seconds", "Two-call agent — ~20 seconds, richer" ], index=0, ) st.divider() st.markdown("### Search settings") top_k = st.slider( "Papers to retrieve", min_value=3, max_value=20, value=10, help="More papers = richer context but slower response" ) year_from = st.selectbox( "Filter from year", options=[None, 2020, 2021, 2022, 2023, 2024, 2025], format_func=lambda x: "All years" if x is None else str(x), ) source_filter = st.selectbox( "Data source", options=[None, "pubmed", "adni"], format_func=lambda x: "All sources" if x is None else x.upper(), ) st.divider() st.markdown("### Pipeline stats") st.metric("Papers indexed", "19,637") st.metric("Chunks in Pinecone", "44,676") st.metric("Embedding model", "PubMedBERT") st.metric("ADNI records", "24,916") st.divider() if st.button("🗑️ Clear conversation", use_container_width=True): st.session_state.messages = [] if "conversation_memory" in st.session_state: st.session_state.conversation_memory.clear() st.rerun() st.divider() st.markdown( "⚠️ *For research purposes only. " "Not for clinical decisions.*" ) # ── Main UI ─────────────────────────────────────────────────────── st.title("🧠 AlzDetect AI") st.markdown( "Ask any question about Alzheimer's disease research. " "Answers are grounded in peer-reviewed PubMed papers with citations." 
# ── Sidebar ────────────────────────────────────────────────────────
with st.sidebar:
    st.image("https://img.icons8.com/color/96/brain.png", width=80)
    st.title("AlzDetect AI")
    st.markdown("*Alzheimer's Research Assistant*")
    st.divider()

    st.markdown("### Analysis mode")
    analysis_mode = st.radio(
        "Choose analysis depth:",
        options=["Quick", "Deep"],
        captions=[
            "Single Claude call — ~5 seconds",
            "Two-call agent — ~20 seconds, richer",
        ],
        index=0,
    )
    st.divider()

    st.markdown("### Search settings")
    top_k = st.slider(
        "Papers to retrieve",
        min_value=3,
        max_value=20,
        value=10,
        help="More papers = richer context but slower response",
    )
    year_from = st.selectbox(
        "Filter from year",
        options=[None, 2020, 2021, 2022, 2023, 2024, 2025],
        format_func=lambda x: "All years" if x is None else str(x),
    )
    source_filter = st.selectbox(
        "Data source",
        options=[None, "pubmed", "adni"],
        format_func=lambda x: "All sources" if x is None else x.upper(),
    )
    st.divider()

    st.markdown("### Pipeline stats")
    st.metric("Papers indexed", "19,637")
    st.metric("Chunks in Pinecone", "44,676")
    st.metric("Embedding model", "PubMedBERT")
    st.metric("ADNI records", "24,916")
    st.divider()

    if st.button("🗑️ Clear conversation", use_container_width=True):
        st.session_state.messages = []
        if "conversation_memory" in st.session_state:
            st.session_state.conversation_memory.clear()
        st.rerun()

    st.divider()
    st.markdown(
        "⚠️ *For research purposes only. "
        "Not for clinical decisions.*"
    )

# ── Main UI ────────────────────────────────────────────────────────
st.title("🧠 AlzDetect AI")
st.markdown(
    "Ask any question about Alzheimer's disease research. "
    "Answers are grounded in peer-reviewed PubMed papers with citations."
)

# ── Chat history ───────────────────────────────────────────────────
if "messages" not in st.session_state:
    st.session_state.messages = []
if "pipeline" not in st.session_state:
    st.session_state.pipeline = None
if "conversation_memory" not in st.session_state:
    st.session_state.conversation_memory = ConversationMemory()

# Display chat history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# ── Example questions ──────────────────────────────────────────────
if not st.session_state.messages:
    st.markdown("### Example questions")
    col1, col2 = st.columns(2)
    examples = [
        "What blood biomarkers detect Alzheimer's early?",
        "How does pTau217 predict Alzheimer's progression?",
        "What is the role of APOE4 in Alzheimer's risk?",
        "What treatments are in clinical trials for Alzheimer's?",
    ]
    for i, example in enumerate(examples):
        col = col1 if i % 2 == 0 else col2
        if col.button(example, key=f"example_{i}"):
            st.session_state.example_question = example
            st.rerun()

# ── Handle example question clicks ─────────────────────────────────
# Route a clicked example through the same answer path as typed input;
# appending it to history and rerunning would display it but never answer it.
example_question = st.session_state.pop("example_question", None)

# ── Chat input ─────────────────────────────────────────────────────
question = st.chat_input("Ask about Alzheimer's research...") or example_question

if question:
    # Validate — worst-case: empty question
    question = question.strip()
    if len(question) < 3:
        st.warning("Please ask a more specific question.")
        st.stop()

    # Add user message to history and render it now; the history loop
    # above ran before this message existed.
    st.session_state.messages.append({
        "role": "user",
        "content": question,
    })
    with st.chat_message("user"):
        st.markdown(question)

    # Generate answer
    with st.chat_message("assistant"):
        mode_label = (
            "Deep analysis..."
            if analysis_mode == "Deep"
            else "Searching 44,676 research chunks..."
        )
        with st.spinner(mode_label):
            try:
                # Call pipeline or agent based on mode
                if analysis_mode == "Deep":
                    # Deferred import keeps Quick-mode startup light
                    from generation.summarizer_agent import SummarizerAgent

                    @st.cache_resource
                    def load_agent() -> SummarizerAgent:
                        return SummarizerAgent()

                    agent = load_agent()
                    answer, diagnostic = agent.ask(
                        question=question,
                        top_k=top_k,
                        source=source_filter,
                        year_from=year_from,
                        memory=st.session_state.conversation_memory,
                    )
                else:
                    pipeline = load_pipeline()
                    answer, diagnostic = pipeline.ask(
                        question=question,
                        top_k=top_k,
                        source=source_filter,
                        year_from=year_from,
                        memory=st.session_state.conversation_memory,
                    )

                # Display answer
                response_text = answer.to_display()
                st.markdown(response_text)

                # Show diagnostic in expander
                with st.expander("🔍 Search details"):
                    col1, col2, col3, col4 = st.columns(4)
                    col1.metric("Chunks retrieved", diagnostic.chunks_retrieved)
                    col2.metric("Top similarity", f"{diagnostic.top_score:.3f}")
                    col3.metric("Response time", f"{diagnostic.total_time_ms:.0f}ms")
                    col4.metric(
                        "Mode",
                        "🔬 Deep" if analysis_mode == "Deep" else "⚡ Quick",
                    )

                # Add to history
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response_text,
                })

            except Exception as e:
                error_msg = f"Pipeline error: {e}"
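                # Completion sketch: the worst-case path promised in the
                # module docstring ("clear error message, not crash"); the
                # exact recovery behavior beyond this point is an assumption.
                logger.error(f"[APP] {error_msg}")
                st.error(error_msg)
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": error_msg,
                })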