"""
app/streamlit_app.py
=====================
ALZDETECT-AI - Streamlit Chatbot Interface.
WHAT: Chat UI for Alzheimer's research questions.
WHY: Makes the RAG pipeline accessible to researchers.
WHO: End users - researchers, clinicians, students.
WHERE: Runs locally or on HuggingFace Spaces.
WHEN: Every user session.
WORST-CASE DESIGN:
- RAG pipeline error → clear error message, not a crash
- Empty question → caught before the pipeline call
- No chunks retrieved → honest "no evidence" message
- Claude timeout → retry handled in rag_pipeline.py
"""
import sys
from pathlib import Path
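# Make the repo root importable so sibling packages
# (generation, vector_store) resolve when Streamlit runs this file.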
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import streamlit as st
from loguru import logger
from generation.rag_pipeline import RAGPipeline, ConversationMemory
from vector_store.pinecone_store import QueryInput
# ── Page config ───────────────────────────────────────────────────
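# st.set_page_config must be the first Streamlit call in the script.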
st.set_page_config(
    page_title="AlzDetect AI",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ── Load pipeline once (cached) ──────────────────────────────────
@st.cache_resource
def load_pipeline() -> RAGPipeline:
"""
Load RAG pipeline once and cache it.
Streamlit reruns on every interaction β€”
caching prevents reloading the model every time.
Analogy: The doctor arrives at the hospital once
in the morning β€” not for every patient visit.
"""
logger.info("[APP] Loading RAG pipeline...")
pipeline = RAGPipeline()
logger.info("[APP] Pipeline ready")
return pipeline
# ── Sidebar ───────────────────────────────────────────────────────
with st.sidebar:
st.image("https://img.icons8.com/color/96/brain.png", width=80)
st.title("AlzDetect AI")
st.markdown("*Alzheimer's Research Assistant*")
st.divider()
st.markdown("### Analysis mode")
analysis_mode = st.radio(
"Choose analysis depth:",
options=["Quick", "Deep"],
captions=[
"Single Claude call β€” ~5 seconds",
"Two-call agent β€” ~20 seconds, richer"
],
index=0,
)
st.divider()
st.markdown("### Search settings")
top_k = st.slider(
"Papers to retrieve",
min_value=3,
max_value=20,
value=10,
help="More papers = richer context but slower response"
)
year_from = st.selectbox(
"Filter from year",
options=[None, 2020, 2021, 2022, 2023, 2024, 2025],
format_func=lambda x: "All years" if x is None else str(x),
)
source_filter = st.selectbox(
"Data source",
options=[None, "pubmed", "adni"],
format_func=lambda x: "All sources" if x is None else x.upper(),
)
st.divider()
st.markdown("### Pipeline stats")
st.metric("Papers indexed", "19,637")
st.metric("Chunks in Pinecone", "44,676")
st.metric("Embedding model", "PubMedBERT")
st.metric("ADNI records", "24,916")
st.divider()
    if st.button("🗑️ Clear conversation", use_container_width=True):
st.session_state.messages = []
if "conversation_memory" in st.session_state:
st.session_state.conversation_memory.clear()
st.rerun()
st.divider()
st.markdown(
"⚠️ *For research purposes only. "
"Not for clinical decisions.*"
)
# ── Main UI ───────────────────────────────────────────────────────
st.title("🧠 AlzDetect AI")
st.markdown(
"Ask any question about Alzheimer's disease research. "
"Answers are grounded in peer-reviewed PubMed papers with citations."
)
# ── Chat history ──────────────────────────────────────────────────
if "messages" not in st.session_state:
st.session_state.messages = []
if "pipeline" not in st.session_state:
st.session_state.pipeline = None
if "conversation_memory" not in st.session_state:
st.session_state.conversation_memory = ConversationMemory()
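# st.session_state survives reruns within a browser session;
# ConversationMemory carries prior turns so follow-up questions
# keep their context.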
# Display chat history
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# ── Example questions ─────────────────────────────────────────────
if not st.session_state.messages:
st.markdown("### Example questions")
col1, col2 = st.columns(2)
examples = [
"What blood biomarkers detect Alzheimer's early?",
"How does pTau217 predict Alzheimer's progression?",
"What is the role of APOE4 in Alzheimer's risk?",
"What treatments are in clinical trials for Alzheimer's?",
]
for i, example in enumerate(examples):
col = col1 if i % 2 == 0 else col2
if col.button(example, key=f"example_{i}"):
st.session_state.example_question = example
st.rerun()
# ── Handle example question clicks ───────────────────────────────
# Route a clicked example through the same path as typed input so it
# is answered immediately; merely appending it to history would leave
# it without a response.
example_question = st.session_state.pop("example_question", None)
# ── Chat input ────────────────────────────────────────────────────
# st.chat_input returns None until the user submits, so a clicked
# example question falls through as the input.
question = st.chat_input(
    "Ask about Alzheimer's research...",
) or example_question
if question:
    # Validate - worst-case: empty or trivially short question
    question = question.strip()
    if len(question) < 3:
        st.warning("Please ask a more specific question.")
        st.stop()
    # Add the user message to history and render it in this run
    # (the history loop above ran before this input arrived).
    st.session_state.messages.append({
        "role": "user",
        "content": question,
    })
    with st.chat_message("user"):
        st.markdown(question)
# Generate answer
with st.chat_message("assistant"):
        mode_label = (
            "🔬 Deep analysis..."
            if analysis_mode == "Deep"
            else "Searching 44,676 research chunks..."
        )
with st.spinner(mode_label):
try:
# Call pipeline or agent based on mode
if analysis_mode == "Deep":
from generation.summarizer_agent import SummarizerAgent
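                    # Lazy import + cached loader: the SummarizerAgent
                    # is only built when Deep mode is first used, then
                    # reused across reruns via st.cache_resource.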
@st.cache_resource
def load_agent() -> SummarizerAgent:
return SummarizerAgent()
                    agent = load_agent()
                    answer, diagnostic = agent.ask(
                        question=question,
                        top_k=top_k,
                        source=source_filter,
                        year_from=year_from,
                        memory=st.session_state.conversation_memory,
                    )
                else:
                    pipeline = load_pipeline()
                    answer, diagnostic = pipeline.ask(
                        question=question,
                        top_k=top_k,
                        source=source_filter,
                        year_from=year_from,
                        memory=st.session_state.conversation_memory,
                    )
# Display answer
response_text = answer.to_display()
st.markdown(response_text)
# Show diagnostic in expander
                with st.expander("🔍 Search details"):
col1, col2, col3, col4 = st.columns(4)
col1.metric("Chunks retrieved", diagnostic.chunks_retrieved)
col2.metric("Top similarity", f"{diagnostic.top_score:.3f}")
col3.metric("Response time", f"{diagnostic.total_time_ms:.0f}ms")
                    col4.metric("Mode", "🔬 Deep" if analysis_mode == "Deep" else "⚡ Quick")
# Add to history
st.session_state.messages.append({
"role": "assistant",
"content": response_text
})
            except Exception as e:
                # Worst-case: show a clear error instead of crashing.
                error_msg = f"Pipeline error: {e}"
                logger.error(f"[APP] {error_msg}")
                st.error(error_msg)