"""MoodMate: Streamlit chat UI that answers from a RAG pipeline or a general LLM."""

import html

import streamlit as st
from streamlit_extras.add_vertical_space import add_vertical_space

from rag_pipeline import build_rag_pipeline

# --- PAGE CONFIG ---
st.set_page_config(
    page_title="💬 MoodMate",
    page_icon="💫",
    layout="centered",
)

# --- CONSTANTS ---
# Values stored in each chat turn's "type" field. DATASET_TYPE is also the
# first option label of the manual-mode radio, so the two compare equal.
DATASET_TYPE = "Dataset-Based Answer"
GENERAL_TYPE = "General Reasoning"
# Number of most recent turns replayed to the LLM as conversational context.
HISTORY_TURNS_FOR_PROMPT = 6
# A dataset answer shorter than this (after stripping) is treated as low quality.
MIN_RAG_ANSWER_CHARS = 50
# Lower-cased phrases suggesting the RAG chain could not really answer.
FALLBACK_KEYWORDS = (
    "cannot answer",
    "no information",
    "based on the context",
    "i'm sorry",
)

# --- CUSTOM CSS (bubbles + badge) ---
# NOTE(review): class names and styles below are a reconstruction of the
# bubble/badge styling; adjust to the intended design.
_CSS = """
<style>
.main-title { text-align: center; font-size: 2.4rem; font-weight: 700; margin-bottom: 0; }
.subtitle { text-align: center; opacity: 0.75; }
.user-bubble, .ai-bubble {
  padding: 0.6rem 0.9rem; border-radius: 0.8rem; margin: 0.35rem 0; max-width: 85%;
}
.user-bubble { background: rgba(49, 130, 206, 0.15); margin-right: auto; }
.ai-bubble { background: rgba(128, 128, 128, 0.12); margin-left: auto; }
.badge {
  display: inline-block; font-size: 0.7rem; padding: 0.05rem 0.45rem;
  border-radius: 0.6rem; margin-left: 0.4rem; vertical-align: middle; color: #fff;
}
.badge-dataset { background: #2f855a; }
.badge-general { background: #6b46c1; }
.doc-card {
  font-size: 0.9rem; padding: 0.5rem 0.7rem; margin-bottom: 0.5rem;
  border-left: 3px solid #2f855a;
}
</style>
"""
st.markdown(_CSS, unsafe_allow_html=True)

# --- HEADER ---
st.markdown('<div class="main-title">💬 MoodMate</div>', unsafe_allow_html=True)
st.markdown(
    '<div class="subtitle">Ask anything about personal, social, or business growth'
    " — powered by RAG + Gemini</div>",
    unsafe_allow_html=True,
)
add_vertical_space(2)


# --- LOAD PIPELINE ---
@st.cache_resource
def load_chain():
    """Build the RAG pipeline once and reuse it across reruns and sessions.

    Returns whatever ``build_rag_pipeline()`` returns; this script unpacks it
    as ``(llm, retriever, rag_chain)``.
    """
    return build_rag_pipeline()


llm, retriever, rag_chain = load_chain()  # retriever is not used in this script

# --- USER SETTINGS ---
st.markdown("### ⚙️ Answer Selection Settings")

# Widget keys let the callbacks read the current settings from session state
# rather than from module globals captured by a previous script run.
auto_mode = st.checkbox(
    "Automatic answer selection (default)", value=True, key="auto_mode"
)

# The manual answer-type radio only appears when automatic mode is off;
# answer_type stays None in automatic mode.
answer_type = None
if not auto_mode:
    answer_type = st.radio(
        "Select answer type:",
        (DATASET_TYPE, "General Reasoning Answer"),
        index=0,
        key="answer_type",
    )

add_vertical_space(1)

# --- SESSION STATE MEMORY ---
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# Seed the text input's key so it persists across runs and callbacks can reset it.
if "input_box" not in st.session_state:
    st.session_state.input_box = ""


# --- RENDERING HELPERS ---
def _to_html(text):
    """Escape arbitrary text for embedding in an HTML snippet, keeping line breaks."""
    return html.escape(str(text)).replace("\n", "<br>")


def _split_qa(doc):
    """Return (question, answer) parsed from a retrieved document.

    Expects ``doc.page_content`` to hold a "Q: ..." line followed by an
    "A: ..." line; missing lines yield empty strings.
    """
    parts = doc.page_content.split("\n")
    question = parts[0].replace("Q: ", "")  # split() always yields >= 1 part
    answer = parts[1].replace("A: ", "") if len(parts) > 1 else ""
    return question, answer


def _render_turn(index, turn):
    """Render one stored turn: user bubble, assistant bubble with badge, sources.

    All interpolated text is HTML-escaped because it comes from the user, the
    LLM, or the dataset and is rendered with ``unsafe_allow_html=True``.
    """
    st.markdown(
        f'<div class="user-bubble">🧑 You: {_to_html(turn["user"])}</div>',
        unsafe_allow_html=True,
    )

    if turn.get("type", GENERAL_TYPE) == DATASET_TYPE:
        badge_html = '<span class="badge badge-dataset">Dataset-Based</span>'
    else:
        badge_html = '<span class="badge badge-general">General Reasoning</span>'
    st.markdown(
        f'<div class="ai-bubble">🤖 Assistant: {_to_html(turn["ai"])} {badge_html}</div>',
        unsafe_allow_html=True,
    )

    # Dataset-based turns that kept their sources get a small expander (top 3).
    docs = turn.get("docs")
    if turn.get("type") == DATASET_TYPE and docs:
        with st.expander(f"📂 Top Retrieved Documents for message {index + 1}"):
            for doc in docs[:3]:
                question, answer = _split_qa(doc)
                st.markdown(
                    f'<div class="doc-card"><b>Q:</b> {_to_html(question)}<br>'
                    f"<b>A:</b> {_to_html(answer)}</div>",
                    unsafe_allow_html=True,
                )


# --- LAYOUT: chat area + input at bottom ---
chat_col = st.container()

# Render the whole conversation on every run (callbacks have already updated it).
with chat_col:
    st.markdown("## 💬 Conversation")
    chat_area = st.container()
    with chat_area:
        for i, turn in enumerate(st.session_state.chat_history):
            _render_turn(i, turn)


# --- ANSWERING HELPERS ---
def _build_history_prompt(history, question, n_keep=HISTORY_TURNS_FOR_PROMPT):
    """Return a plain-text transcript of the last *n_keep* turns plus *question*."""
    recent = history[-n_keep:] if n_keep > 0 else []  # [-0:] would keep everything
    transcript = "".join(f"User: {t['user']}\nAI: {t['ai']}\n" for t in recent)
    return f"{transcript}User: {question}\nAI:"


def _is_weak_rag_answer(answer):
    """Return True if a dataset answer is too short or reads like a refusal."""
    if len(answer.strip()) < MIN_RAG_ANSWER_CHARS:
        return True
    # Normalize curly apostrophes so "I’m sorry" also matches "i'm sorry".
    normalized = answer.lower().replace("\u2019", "'")
    return any(keyword in normalized for keyword in FALLBACK_KEYWORDS)


def _ask_dataset(question):
    """Query the RAG chain; return (answer_text, source_documents)."""
    result = rag_chain({"question": question})
    return result.get("answer", ""), result.get("source_documents", [])


def _ask_general(question):
    """Ask the LLM directly with recent chat history as context; return its text."""
    prompt = _build_history_prompt(st.session_state.chat_history, question)
    response = llm.invoke(prompt)
    return getattr(response, "content", str(response))


# --- SEND / CLEAR CALLBACKS ---
def handle_send():
    """Answer the text in the input box and append the turn to the chat history.

    Used as the Send button's ``on_click`` callback: Streamlit runs it before
    re-executing the script, so the new turn shows up in that same rerun and
    resetting ``input_box`` here is allowed.

    Automatic mode tries the dataset (RAG) answer first and falls back to the
    general LLM when ``_is_weak_rag_answer`` flags it; manual mode uses the
    answer type selected in the radio.
    """
    query = st.session_state.input_box.strip()
    if not query:
        st.warning("Please enter a message.")
        return

    use_auto = st.session_state.get("auto_mode", True)
    selected_type = st.session_state.get("answer_type", DATASET_TYPE)
    docs = []

    with st.spinner("🔍 Thinking and retrieving relevant information..."):
        if use_auto or selected_type == DATASET_TYPE:
            rag_answer, docs = _ask_dataset(query)
            if use_auto and _is_weak_rag_answer(rag_answer):
                chosen_answer, chosen_type = _ask_general(query), GENERAL_TYPE
            else:
                chosen_answer, chosen_type = rag_answer, DATASET_TYPE
        else:
            chosen_answer, chosen_type = _ask_general(query), GENERAL_TYPE

    # Sources are only kept for turns that actually show the dataset answer.
    st.session_state.chat_history.append({
        "user": query,
        "ai": chosen_answer,
        "type": chosen_type,
        "docs": docs if chosen_type == DATASET_TYPE else None,
    })
    st.session_state.input_box = ""  # clear the input after sending


def handle_clear():
    """Clear the conversation and the input box.

    No explicit ``st.rerun()``: Streamlit reruns automatically after a
    callback, and calling ``st.rerun()`` inside a callback is a no-op.
    """
    st.session_state.chat_history.clear()
    st.session_state.input_box = ""


# --- INPUT AREA (rendered last, below the conversation) ---
# A container groups the input widgets; an HTML <div> opened and closed in two
# separate st.markdown calls cannot wrap the widgets rendered between them.
with st.container():
    query = st.text_input(
        "💭 Type your message here...",
        key="input_box",
        placeholder="e.g. How can I improve my communication skills?",
        label_visibility="collapsed",
    )

    col1, col2 = st.columns([0.2, 0.8])
    with col1:
        st.button("Send 💬", key="send_button", on_click=handle_send)
    with col2:
        st.button(
            "🧹 Clear Chat",
            key="clear_button",
            help="Clears conversation history (not persistent).",
            on_click=handle_clear,
        )