Spaces:
Sleeping
Sleeping
| # ============================================================ | |
| # FILE: app.py (Consolidated Root) | |
| # PURPOSE: Single-file deployment for Safeguard AI. | |
| # Contains UI, Navigation, Groq Bot Logic, and In-Memory Analytics. | |
| # *UPDATED: Implements Lazy-Loading for SHAP XAI* | |
| # ============================================================ | |
| import os | |
| import sys | |
| import pandas as pd | |
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| from dotenv import load_dotenv | |
| from groq import Groq | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # PATH SETUP & BACKEND IMPORTS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Add project root to path so we can import the heavy ML engines from src/ | |
| project_root = os.path.dirname(os.path.abspath(__file__)) | |
| if project_root not in sys.path: | |
| sys.path.append(project_root) | |
| # Import the core AI engines (ensure your src/ folder is uploaded to HF!) | |
| try: | |
| from src.core_model.predict import MindGuardPredictor | |
| from src.rag_engine.retriever import MindGuardRetriever | |
| from src.audio.speech_to_text import MindGuardAudioProcessor | |
| from src.explainability.shap_explainer import MindGuardSHAPExplainer | |
| except ImportError as e: | |
| st.error(f"Failed to import backend modules. Ensure the 'src' folder is uploaded. Details: {e}") | |
| st.stop() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # PAGE CONFIGURATION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| st.set_page_config(page_title="Mindguard AI", page_icon="π‘οΈ", layout="wide") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CORE CHATBOT ENGINE | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class SafeguardChatbot: | |
| """Orchestrates prediction, semantic search, and Groq LLM response.""" | |
| def __init__(self): | |
| # 1. Load API Key | |
| load_dotenv(os.path.join(project_root, ".env")) | |
| api_key = os.environ.get("GROQ_API_KEY") | |
| if not api_key: | |
| st.error("β GROQ_API_KEY not found in Hugging Face Secrets!") | |
| st.stop() | |
| self.client = Groq(api_key=api_key) | |
| # 2. Wake up tools | |
| self.predictor = MindGuardPredictor() | |
| self.retriever = MindGuardRetriever() | |
| self.audio_processor = MindGuardAudioProcessor() | |
| self.system_prompt = """ | |
| You are Safeguard AI, a highly empathetic, clinical-grade mental health AI. | |
| Your goal is to de-escalate emotional distress and provide actionable coping strategies. | |
| STRICT RULES: | |
| 1. NEVER hallucinate medical advice. ONLY use the 'Clinical Strategy' provided in the prompt. | |
| 2. Keep your response conversational, warm, and easy to read (use short paragraphs). | |
| 3. Do not sound like a robot reading a textbook. Weave the clinical strategy naturally into your empathy. | |
| 4. If the Risk Level is 'High', prioritize grounding the user immediately. | |
| """ | |
| def generate_response(self, user_input, chat_history_list): | |
| """Pipeline: Predict -> Retrieve -> Remember -> Generate.""" | |
| # 1. Core Model Prediction | |
| prediction = self.predictor.predict(user_input) | |
| emotion = prediction['emotion'] | |
| risk = prediction['risk_level'] | |
| # 2. RAG Retrieval | |
| strategy = self.retriever.get_coping_strategy( | |
| user_query=user_input, | |
| emotion_filter=emotion | |
| ) | |
| # 3. Extract recent history directly from Streamlit session state | |
| history_text = "No previous conversation." | |
| if chat_history_list: | |
| recent = chat_history_list[-4:] # Grab last two exchanges | |
| history_text = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in recent]) | |
| # 4. Prompt Engineering | |
| augmented_prompt = f""" | |
| --- RECENT CONVERSATION HISTORY --- | |
| {history_text} | |
| --- CURRENT SITUATION --- | |
| User's New Message: "{user_input}" | |
| AI Core Diagnosis: {emotion} | |
| Assessed Risk Level: {risk} | |
| Required Clinical Strategy to Teach the User: | |
| {strategy} | |
| Draft your response to the user's new message now: | |
| """ | |
| # 5. Groq Generation | |
| chat_completion = self.client.chat.completions.create( | |
| messages=[ | |
| {"role": "system", "content": self.system_prompt}, | |
| {"role": "user", "content": augmented_prompt} | |
| ], | |
| model="llama-3.3-70b-versatile", | |
| temperature=0.3, | |
| ) | |
| final_response = chat_completion.choices[0].message.content | |
| return final_response, emotion, risk | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # RESOURCE CACHING | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def get_safeguard_bot(): | |
| return SafeguardChatbot() | |
| def get_shap_explainer(): | |
| return MindGuardSHAPExplainer() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # UI HELPERS (Badges & Formatting) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| RISK_COLORS = { | |
| "High": ("#ff4b4b", "white"), "Medium": ("#ffa500", "white"), "Low": ("#21c354", "white") | |
| } | |
| EMOTION_COLORS = { | |
| "Suicidal": "#ff4b4b", "Depression": "#e05260", "Anxiety": "#e07052", "Bipolar": "#c0392b", | |
| "Stress": "#e67e22", "Personality disorder": "#9b59b6", "joy": "#21c354", "love": "#2ecc71", | |
| "gratitude": "#27ae60", "optimism": "#16a085", "Normal": "#3498db", "neutral": "#3498db", | |
| "sadness": "#e08000", "fear": "#e74c3c", "anger": "#c0392b" | |
| } | |
| def _emotion_badge(emotion: str) -> str: | |
| color = EMOTION_COLORS.get(emotion, "#555555") | |
| return f'<span style="background:{color};color:white;padding:3px 10px;border-radius:12px;font-size:13px;font-weight:600;">π§ {emotion}</span>' | |
| def _risk_badge(risk: str) -> str: | |
| bg, fg = RISK_COLORS.get(risk, ("#888888", "white")) | |
| icon = {"High": "π¨", "Medium": "β οΈ", "Low": "β "}.get(risk, "β’") | |
| return f'<span style="background:{bg};color:{fg};padding:3px 10px;border-radius:12px;font-size:13px;font-weight:600;">{icon} Risk: {risk}</span>' | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # PAGE 1: CHAT COMPANION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def render_chat(): | |
| st.title("π§ Mindguard AI Companion") | |
| st.markdown("Your clinical-grade, empathetic AI. Type a message or upload a voice note.") | |
| bot = get_safeguard_bot() | |
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| # SIDEBAR: Audio & Controls | |
| with st.sidebar: | |
| st.header("ποΈ Voice Input") | |
| audio_value = st.audio_input("Record a voice note") | |
| if audio_value: | |
| current_audio_size = len(audio_value.getvalue()) | |
| if ("last_audio_size" not in st.session_state or st.session_state.last_audio_size != current_audio_size): | |
| st.session_state.last_audio_size = current_audio_size | |
| os.makedirs(os.path.join(project_root, "data", "raw"), exist_ok=True) | |
| temp_path = os.path.join(project_root, "data", "raw", "temp_streamlit.wav") | |
| with open(temp_path, "wb") as f: | |
| f.write(audio_value.getbuffer()) | |
| with st.spinner("Transcribing via Whisperβ¦"): | |
| transcribed_text = bot.audio_processor.transcribe(temp_path) | |
| # --- DEFERRED EXECUTION FLAGS --- | |
| st.session_state.latest_analyzable_text = transcribed_text | |
| st.session_state.shap_is_stale = True | |
| response, emotion, risk = bot.generate_response(transcribed_text, st.session_state.messages) | |
| st.session_state.messages.append({"role": "user", "content": f"π€ **Voice Note:** *{transcribed_text}*"}) | |
| st.session_state.messages.append({ | |
| "role": "assistant", "content": response, "emotion": emotion, "risk": risk | |
| }) | |
| st.divider() | |
| if st.button("ποΈ Clear Chat"): | |
| st.session_state.messages = [] | |
| st.session_state.pop("last_audio_size", None) | |
| st.session_state.pop("latest_analyzable_text", None) | |
| st.session_state.pop("shap_is_stale", None) | |
| st.rerun() | |
| # MAIN WINDOW: Chat History | |
| for msg in st.session_state.messages: | |
| with st.chat_message(msg["role"]): | |
| st.markdown(msg["content"]) | |
| if msg["role"] == "assistant" and "emotion" in msg: | |
| col1, col2 = st.columns([1, 1]) | |
| with col1: st.markdown(_emotion_badge(msg["emotion"]), unsafe_allow_html=True) | |
| with col2: st.markdown(_risk_badge(msg["risk"]), unsafe_allow_html=True) | |
| # MAIN WINDOW: Text Input | |
| if prompt := st.chat_input("How are you feeling right now?"): | |
| st.session_state.messages.append({"role": "user", "content": prompt}) | |
| # --- DEFERRED EXECUTION FLAGS --- | |
| st.session_state.latest_analyzable_text = prompt | |
| st.session_state.shap_is_stale = True | |
| with st.chat_message("user"): | |
| st.markdown(prompt) | |
| with st.chat_message("assistant"): | |
| with st.spinner("Diagnosing and retrieving clinical strategyβ¦"): | |
| response, emotion, risk = bot.generate_response(prompt, st.session_state.messages) | |
| st.markdown(response) | |
| col1, col2 = st.columns([1, 1]) | |
| with col1: st.markdown(_emotion_badge(emotion), unsafe_allow_html=True) | |
| with col2: st.markdown(_risk_badge(risk), unsafe_allow_html=True) | |
| st.session_state.messages.append({ | |
| "role": "assistant", "content": response, "emotion": emotion, "risk": risk | |
| }) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # PAGE 2: CLINICAL DASHBOARD | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def render_dashboard(): | |
| st.title("π Clinical Overview") | |
| st.markdown("Real-time emotional tracking and XAI reports for the current session.") | |
| tab1, tab2 = st.tabs(["π Analytics", "π¬ XAI Report"]) | |
| SHAP_HTML_PATH = os.path.join(project_root, "artifacts", "shap_report.html") | |
| with tab1: | |
| history_data = [] | |
| for msg in st.session_state.get("messages", []): | |
| if msg["role"] == "assistant" and "emotion" in msg: | |
| history_data.append({ | |
| "diagnosed_emotion": msg["emotion"], | |
| "risk_level": msg["risk"], | |
| "timestamp": pd.Timestamp.now().strftime("%H:%M:%S") | |
| }) | |
| if not history_data: | |
| st.info("No data yet. Start chatting to generate live analytics.") | |
| return | |
| df = pd.DataFrame(history_data) | |
| total = len(df) | |
| high_risk = len(df[df["risk_level"] == "High"]) | |
| top_emotion = df["diagnosed_emotion"].mode()[0] if not df.empty else "N/A" | |
| col1, col2, col3 = st.columns(3) | |
| col1.metric("Interactions This Session", total) | |
| col2.metric("High Risk Flags", high_risk, delta_color="inverse") | |
| col3.metric("Primary Emotion", top_emotion) | |
| st.divider() | |
| st.subheader("Risk Level Summary") | |
| risk_cols = st.columns(3) | |
| for i, level in enumerate(["High", "Medium", "Low"]): | |
| count = len(df[df["risk_level"] == level]) | |
| icons = {"High": "π¨", "Medium": "β οΈ", "Low": "β "} | |
| risk_cols[i].metric(f"{icons[level]} {level}", count) | |
| st.divider() | |
| chart_col1, chart_col2 = st.columns(2) | |
| with chart_col1: | |
| st.subheader("Emotion Frequency") | |
| st.bar_chart(df["diagnosed_emotion"].value_counts()) | |
| with chart_col2: | |
| st.subheader("Risk Level Distribution") | |
| st.bar_chart(df["risk_level"].value_counts()) | |
| with tab2: | |
| st.subheader("π¬ Last SHAP Word-Level Explanation") | |
| # --- THE FIX: Lazy Loading Execution --- | |
| if st.session_state.get("shap_is_stale", False) and "latest_analyzable_text" in st.session_state: | |
| with st.spinner("Calculating XAI feature importance for the latest message... (This takes a moment)"): | |
| shap_ex = get_shap_explainer() | |
| shap_ex.generate_visual_report(st.session_state.latest_analyzable_text) | |
| st.session_state.shap_is_stale = False | |
| # Render the HTML if it exists | |
| if os.path.exists(SHAP_HTML_PATH): | |
| with open(SHAP_HTML_PATH, "r", encoding="utf-8") as f: | |
| shap_html = f.read() | |
| safe_html = f'<div style="background-color: white; padding: 20px; border-radius: 10px;">{shap_html}</div>' | |
| components.html(safe_html, height=500, scrolling=True) | |
| st.caption(f"Report path: `{SHAP_HTML_PATH}`") | |
| else: | |
| st.info("Send a message in the Chat Companion tab to generate the first XAI report.") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MAIN ROUTER | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| st.sidebar.title("π‘οΈ Mindguard AI") | |
| st.sidebar.markdown("Welcome to the control panel.") | |
| page = st.sidebar.radio("Navigation", ["π¬ Chat Companion", "π Clinical Dashboard"]) | |
| if page == "π¬ Chat Companion": | |
| render_chat() | |
| elif page == "π Clinical Dashboard": | |
| render_dashboard() |