import streamlit as st
from rag_pipeline import build_rag_pipeline
from streamlit_extras.add_vertical_space import add_vertical_space
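# streamlit-extras is a third-party helper package providing add_vertical_space;
# install it alongside streamlit with: pip install streamlit-extras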
# --- PAGE CONFIG ---
st.set_page_config(
    page_title="💬 MoodMate",
    page_icon="💫",
    layout="centered",
)
# --- CUSTOM CSS (bubbles + badge) ---
st.markdown("""
""", unsafe_allow_html=True)
# --- HEADER ---
st.markdown(
    '<h1 style="text-align:center;">💬 MoodMate</h1>',
    unsafe_allow_html=True,
)
st.markdown(
    '<p style="text-align:center;">Ask anything about personal, social, or '
    'business growth, powered by RAG + Gemini</p>',
    unsafe_allow_html=True,
)
add_vertical_space(2)
# --- LOAD PIPELINE ---
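# st.cache_resource runs build_rag_pipeline() once per process and reuses the
# result across reruns, so heavyweight setup is not repeated on every message.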
@st.cache_resource
def load_chain():
    return build_rag_pipeline()
llm, retriever, rag_chain = load_chain()
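# Note: only llm and rag_chain are used below; retriever is unpacked for completeness.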
# --- USER SETTINGS ---
st.markdown("### βοΈ Answer Selection Settings")
# Automatic vs Manual mode
auto_mode = st.checkbox("Automatic answer selection (default)", value=True)
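# In automatic mode, handle_send() tries the dataset-based (RAG) answer first
# and falls back to general LLM reasoning when that answer looks weak.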
# Manual answer type selection appears only if auto_mode is off
if not auto_mode:
    answer_type = st.radio(
        "Select answer type:",
        ("Dataset-Based Answer", "General Reasoning Answer"),
        index=0,
    )
add_vertical_space(1)
# --- SESSION STATE MEMORY ---
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# Ensure input_box key exists so it persists across runs
if "input_box" not in st.session_state:
    st.session_state.input_box = ""
# --- LAYOUT: chat area + input at bottom ---
chat_col = st.container()
# Render chat area (so it updates live on each run)
with chat_col:
    st.markdown("## 💬 Conversation")
    chat_area = st.container()
    with chat_area:
        # Render each turn in order
        for i, turn in enumerate(st.session_state.chat_history):
            # User bubble (left)
            st.markdown(
                f'<div class="user-bubble">🧠 You: {turn["user"]}</div>',
                unsafe_allow_html=True,
            )
            # Assistant bubble with a subtle answer-type badge
            typ = turn.get("type", "General Reasoning")
            badge_html = (
                '<span class="badge dataset">Dataset-Based</span>'
                if typ == "Dataset-Based Answer"
                else '<span class="badge general">General Reasoning</span>'
            )
            st.markdown(
                f'<div class="ai-bubble">🤖 Assistant: {turn["ai"]} {badge_html}</div>',
                unsafe_allow_html=True,
            )
            # If dataset-based and docs were stored, show a small expander for them
            if turn.get("type") == "Dataset-Based Answer" and turn.get("docs"):
                with st.expander(f"📄 Top Retrieved Documents for message {i+1}"):
                    for d in turn["docs"][:3]:
                        # Each document is stored as "Q: ...\nA: ..." text
                        parts = d.page_content.split("\n")
                        q_text = parts[0].replace("Q: ", "") if len(parts) > 0 else ""
                        a_text = parts[1].replace("A: ", "") if len(parts) > 1 else ""
                        st.markdown(
                            f'<div class="doc-card"><b>Q:</b> {q_text}<br><b>A:</b> {a_text}</div>',
                            unsafe_allow_html=True,
                        )
# --- SEND CALLBACK LOGIC ---
def handle_send():
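    """on_click callback for the Send button: read the input box, produce a
    dataset-based (RAG) or general-reasoning answer, append the turn to the
    chat history, then clear the input."""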
    query = st.session_state.input_box.strip()
    if not query:
        st.warning("Please enter a message.")
        return
    with st.spinner("🔍 Thinking and retrieving relevant information..."):
        # --- Build unified chat history for contextual prompting ---
        N_keep = 6  # keep the last 6 turns
        history_for_prompt = st.session_state.chat_history[-N_keep:]
        full_prompt = ""
        for turn in history_for_prompt:
            full_prompt += f"User: {turn['user']}\nAI: {turn['ai']}\n"
        full_prompt += f"User: {query}\nAI:"
        rag_answer, general_answer, docs = "", "", []
        # --- AUTO MODE ---
        if auto_mode:
            # Step 1: Try the dataset-based (RAG) answer first
            rag_result = rag_chain({"question": query})
            rag_answer = rag_result.get("answer", "")
            docs = rag_result.get("source_documents", [])
            # Step 2: Evaluate RAG answer quality to decide whether to show the
            # dataset-based answer or fall back to general reasoning:
            # - rag_weak: any "bad" keyword appears in the answer
            # - rag_too_short: the answer is under 50 characters (likely low quality)
            # The dataset answer is shown only if it is *good enough*.
            fallback_keywords = ["cannot answer", "no information", "based on the context", "i'm sorry"]
            rag_too_short = len(rag_answer.strip()) < 50
            rag_weak = any(kw in rag_answer.lower() for kw in fallback_keywords)
            if rag_weak or rag_too_short:
                # Step 3: Fall back to general reasoning ONLY if RAG is weak
                general_response_obj = llm.invoke(full_prompt)
                general_answer = getattr(general_response_obj, "content", str(general_response_obj))
                chosen_answer = general_answer
                chosen_type = "General Reasoning"
            else:
                chosen_answer = rag_answer
                chosen_type = "Dataset-Based Answer"
        # --- MANUAL MODE ---
        else:
            if answer_type == "Dataset-Based Answer":
                rag_result = rag_chain({"question": query})
                rag_answer = rag_result.get("answer", "")
                docs = rag_result.get("source_documents", [])
                chosen_answer = rag_answer
                chosen_type = "Dataset-Based Answer"
            else:
                general_response_obj = llm.invoke(full_prompt)
                general_answer = getattr(general_response_obj, "content", str(general_response_obj))
                chosen_answer = general_answer
                chosen_type = "General Reasoning"
        # --- Append to the unified chat history ---
        st.session_state.chat_history.append({
            "user": query,
            "ai": chosen_answer,
            "type": chosen_type,
            "docs": docs if chosen_type == "Dataset-Based Answer" else None,
        })
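        # Persisting the retrieved docs with the turn lets the expander in the
        # chat area re-render them on every rerun.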
    # ✅ Clear the input box after sending (safe here because this runs inside
    # an on_click callback, before the text_input is re-instantiated)
    st.session_state.input_box = ""
# --- INPUT AREA (fixed bar at the bottom) ---
st.markdown('<div class="input-bar">', unsafe_allow_html=True)
query = st.text_input(
    "📝 Type your message here...",
    key="input_box",
    placeholder="e.g. How can I improve my communication skills?",
    label_visibility="collapsed",
)
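# Pressing Enter only commits the text and reruns the script; handle_send is
# invoked solely via the Send button's on_click below.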
col1, col2 = st.columns([0.2, 0.8])
with col1:
st.button("Send π¬", key="send_button", on_click=handle_send)
with col2:
    # Callbacks already trigger a rerun, so no explicit st.rerun() is needed
    # (and it is not allowed inside an on_click callback).
    st.button(
        "🧹 Clear Chat",
        key="clear_button",
        help="Clears conversation history (not persistent).",
        on_click=lambda: (
            st.session_state.chat_history.clear(),
            st.session_state.update({"input_box": ""}),
        ),
    )
st.markdown('</div>', unsafe_allow_html=True)