# ============================================================ # FILE: app.py (Consolidated Root) # PURPOSE: Single-file deployment for Safeguard AI. # Contains UI, Navigation, Groq Bot Logic, and In-Memory Analytics. # *UPDATED: Implements Lazy-Loading for SHAP XAI* # ============================================================ import os import sys import pandas as pd import streamlit as st import streamlit.components.v1 as components from dotenv import load_dotenv from groq import Groq # ───────────────────────────────────────────────────────────── # PATH SETUP & BACKEND IMPORTS # ───────────────────────────────────────────────────────────── # Add project root to path so we can import the heavy ML engines from src/ project_root = os.path.dirname(os.path.abspath(__file__)) if project_root not in sys.path: sys.path.append(project_root) # Import the core AI engines (ensure your src/ folder is uploaded to HF!) try: from src.core_model.predict import MindGuardPredictor from src.rag_engine.retriever import MindGuardRetriever from src.audio.speech_to_text import MindGuardAudioProcessor from src.explainability.shap_explainer import MindGuardSHAPExplainer except ImportError as e: st.error(f"Failed to import backend modules. Ensure the 'src' folder is uploaded. Details: {e}") st.stop() # ───────────────────────────────────────────────────────────── # PAGE CONFIGURATION # ───────────────────────────────────────────────────────────── st.set_page_config(page_title="Mindguard AI", page_icon="🛡️", layout="wide") # ───────────────────────────────────────────────────────────── # CORE CHATBOT ENGINE # ───────────────────────────────────────────────────────────── class SafeguardChatbot: """Orchestrates prediction, semantic search, and Groq LLM response.""" def __init__(self): # 1. Load API Key load_dotenv(os.path.join(project_root, ".env")) api_key = os.environ.get("GROQ_API_KEY") if not api_key: st.error("❌ GROQ_API_KEY not found in Hugging Face Secrets!") st.stop() self.client = Groq(api_key=api_key) # 2. Wake up tools self.predictor = MindGuardPredictor() self.retriever = MindGuardRetriever() self.audio_processor = MindGuardAudioProcessor() self.system_prompt = """ You are Safeguard AI, a highly empathetic, clinical-grade mental health AI. Your goal is to de-escalate emotional distress and provide actionable coping strategies. STRICT RULES: 1. NEVER hallucinate medical advice. ONLY use the 'Clinical Strategy' provided in the prompt. 2. Keep your response conversational, warm, and easy to read (use short paragraphs). 3. Do not sound like a robot reading a textbook. Weave the clinical strategy naturally into your empathy. 4. If the Risk Level is 'High', prioritize grounding the user immediately. """ def generate_response(self, user_input, chat_history_list): """Pipeline: Predict -> Retrieve -> Remember -> Generate.""" # 1. Core Model Prediction prediction = self.predictor.predict(user_input) emotion = prediction['emotion'] risk = prediction['risk_level'] # 2. RAG Retrieval strategy = self.retriever.get_coping_strategy( user_query=user_input, emotion_filter=emotion ) # 3. Extract recent history directly from Streamlit session state history_text = "No previous conversation." if chat_history_list: recent = chat_history_list[-4:] # Grab last two exchanges history_text = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in recent]) # 4. Prompt Engineering augmented_prompt = f""" --- RECENT CONVERSATION HISTORY --- {history_text} --- CURRENT SITUATION --- User's New Message: "{user_input}" AI Core Diagnosis: {emotion} Assessed Risk Level: {risk} Required Clinical Strategy to Teach the User: {strategy} Draft your response to the user's new message now: """ # 5. Groq Generation chat_completion = self.client.chat.completions.create( messages=[ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": augmented_prompt} ], model="llama-3.3-70b-versatile", temperature=0.3, ) final_response = chat_completion.choices[0].message.content return final_response, emotion, risk # ───────────────────────────────────────────────────────────── # RESOURCE CACHING # ───────────────────────────────────────────────────────────── @st.cache_resource def get_safeguard_bot(): return SafeguardChatbot() @st.cache_resource def get_shap_explainer(): return MindGuardSHAPExplainer() # ───────────────────────────────────────────────────────────── # UI HELPERS (Badges & Formatting) # ───────────────────────────────────────────────────────────── RISK_COLORS = { "High": ("#ff4b4b", "white"), "Medium": ("#ffa500", "white"), "Low": ("#21c354", "white") } EMOTION_COLORS = { "Suicidal": "#ff4b4b", "Depression": "#e05260", "Anxiety": "#e07052", "Bipolar": "#c0392b", "Stress": "#e67e22", "Personality disorder": "#9b59b6", "joy": "#21c354", "love": "#2ecc71", "gratitude": "#27ae60", "optimism": "#16a085", "Normal": "#3498db", "neutral": "#3498db", "sadness": "#e08000", "fear": "#e74c3c", "anger": "#c0392b" } def _emotion_badge(emotion: str) -> str: color = EMOTION_COLORS.get(emotion, "#555555") return f'🧠 {emotion}' def _risk_badge(risk: str) -> str: bg, fg = RISK_COLORS.get(risk, ("#888888", "white")) icon = {"High": "🚨", "Medium": "⚠️", "Low": "✅"}.get(risk, "•") return f'{icon} Risk: {risk}' # ───────────────────────────────────────────────────────────── # PAGE 1: CHAT COMPANION # ───────────────────────────────────────────────────────────── def render_chat(): st.title("🧠 Mindguard AI Companion") st.markdown("Your clinical-grade, empathetic AI. Type a message or upload a voice note.") bot = get_safeguard_bot() if "messages" not in st.session_state: st.session_state.messages = [] # SIDEBAR: Audio & Controls with st.sidebar: st.header("🎙️ Voice Input") audio_value = st.audio_input("Record a voice note") if audio_value: current_audio_size = len(audio_value.getvalue()) if ("last_audio_size" not in st.session_state or st.session_state.last_audio_size != current_audio_size): st.session_state.last_audio_size = current_audio_size os.makedirs(os.path.join(project_root, "data", "raw"), exist_ok=True) temp_path = os.path.join(project_root, "data", "raw", "temp_streamlit.wav") with open(temp_path, "wb") as f: f.write(audio_value.getbuffer()) with st.spinner("Transcribing via Whisper…"): transcribed_text = bot.audio_processor.transcribe(temp_path) # --- DEFERRED EXECUTION FLAGS --- st.session_state.latest_analyzable_text = transcribed_text st.session_state.shap_is_stale = True response, emotion, risk = bot.generate_response(transcribed_text, st.session_state.messages) st.session_state.messages.append({"role": "user", "content": f"🎤 **Voice Note:** *{transcribed_text}*"}) st.session_state.messages.append({ "role": "assistant", "content": response, "emotion": emotion, "risk": risk }) st.divider() if st.button("🗑️ Clear Chat"): st.session_state.messages = [] st.session_state.pop("last_audio_size", None) st.session_state.pop("latest_analyzable_text", None) st.session_state.pop("shap_is_stale", None) st.rerun() # MAIN WINDOW: Chat History for msg in st.session_state.messages: with st.chat_message(msg["role"]): st.markdown(msg["content"]) if msg["role"] == "assistant" and "emotion" in msg: col1, col2 = st.columns([1, 1]) with col1: st.markdown(_emotion_badge(msg["emotion"]), unsafe_allow_html=True) with col2: st.markdown(_risk_badge(msg["risk"]), unsafe_allow_html=True) # MAIN WINDOW: Text Input if prompt := st.chat_input("How are you feeling right now?"): st.session_state.messages.append({"role": "user", "content": prompt}) # --- DEFERRED EXECUTION FLAGS --- st.session_state.latest_analyzable_text = prompt st.session_state.shap_is_stale = True with st.chat_message("user"): st.markdown(prompt) with st.chat_message("assistant"): with st.spinner("Diagnosing and retrieving clinical strategy…"): response, emotion, risk = bot.generate_response(prompt, st.session_state.messages) st.markdown(response) col1, col2 = st.columns([1, 1]) with col1: st.markdown(_emotion_badge(emotion), unsafe_allow_html=True) with col2: st.markdown(_risk_badge(risk), unsafe_allow_html=True) st.session_state.messages.append({ "role": "assistant", "content": response, "emotion": emotion, "risk": risk }) # ───────────────────────────────────────────────────────────── # PAGE 2: CLINICAL DASHBOARD # ───────────────────────────────────────────────────────────── def render_dashboard(): st.title("📊 Clinical Overview") st.markdown("Real-time emotional tracking and XAI reports for the current session.") tab1, tab2 = st.tabs(["📈 Analytics", "🔬 XAI Report"]) SHAP_HTML_PATH = os.path.join(project_root, "artifacts", "shap_report.html") with tab1: history_data = [] for msg in st.session_state.get("messages", []): if msg["role"] == "assistant" and "emotion" in msg: history_data.append({ "diagnosed_emotion": msg["emotion"], "risk_level": msg["risk"], "timestamp": pd.Timestamp.now().strftime("%H:%M:%S") }) if not history_data: st.info("No data yet. Start chatting to generate live analytics.") return df = pd.DataFrame(history_data) total = len(df) high_risk = len(df[df["risk_level"] == "High"]) top_emotion = df["diagnosed_emotion"].mode()[0] if not df.empty else "N/A" col1, col2, col3 = st.columns(3) col1.metric("Interactions This Session", total) col2.metric("High Risk Flags", high_risk, delta_color="inverse") col3.metric("Primary Emotion", top_emotion) st.divider() st.subheader("Risk Level Summary") risk_cols = st.columns(3) for i, level in enumerate(["High", "Medium", "Low"]): count = len(df[df["risk_level"] == level]) icons = {"High": "🚨", "Medium": "⚠️", "Low": "✅"} risk_cols[i].metric(f"{icons[level]} {level}", count) st.divider() chart_col1, chart_col2 = st.columns(2) with chart_col1: st.subheader("Emotion Frequency") st.bar_chart(df["diagnosed_emotion"].value_counts()) with chart_col2: st.subheader("Risk Level Distribution") st.bar_chart(df["risk_level"].value_counts()) with tab2: st.subheader("🔬 Last SHAP Word-Level Explanation") # --- THE FIX: Lazy Loading Execution --- if st.session_state.get("shap_is_stale", False) and "latest_analyzable_text" in st.session_state: with st.spinner("Calculating XAI feature importance for the latest message... (This takes a moment)"): shap_ex = get_shap_explainer() shap_ex.generate_visual_report(st.session_state.latest_analyzable_text) st.session_state.shap_is_stale = False # Render the HTML if it exists if os.path.exists(SHAP_HTML_PATH): with open(SHAP_HTML_PATH, "r", encoding="utf-8") as f: shap_html = f.read() safe_html = f'