Spaces:
Running
Running
File size: 15,645 Bytes
7682193 6295bd7 7682193 5603305 7682193 46f84b9 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 46f84b9 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 6295bd7 7682193 9ab0852 6295bd7 9ab0852 7682193 5603305 7682193 f1e20e4 7682193 6295bd7 7682193 5603305 7682193 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 | # ============================================================
# FILE: app.py (Consolidated Root)
# PURPOSE: Single-file deployment for Safeguard AI.
# Contains UI, Navigation, Groq Bot Logic, and In-Memory Analytics.
# *UPDATED: Implements Lazy-Loading for SHAP XAI*
# ============================================================
import os
import sys
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from dotenv import load_dotenv
from groq import Groq
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# PATH SETUP & BACKEND IMPORTS
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Add project root to path so we can import the heavy ML engines from src/
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.append(project_root)
# Import the core AI engines (ensure your src/ folder is uploaded to HF!)
try:
from src.core_model.predict import MindGuardPredictor
from src.rag_engine.retriever import MindGuardRetriever
from src.audio.speech_to_text import MindGuardAudioProcessor
from src.explainability.shap_explainer import MindGuardSHAPExplainer
except ImportError as e:
st.error(f"Failed to import backend modules. Ensure the 'src' folder is uploaded. Details: {e}")
st.stop()
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# PAGE CONFIGURATION
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
st.set_page_config(page_title="Mindguard AI", page_icon="π‘οΈ", layout="wide")
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# CORE CHATBOT ENGINE
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
class SafeguardChatbot:
"""Orchestrates prediction, semantic search, and Groq LLM response."""
def __init__(self):
# 1. Load API Key
load_dotenv(os.path.join(project_root, ".env"))
api_key = os.environ.get("GROQ_API_KEY")
if not api_key:
st.error("β GROQ_API_KEY not found in Hugging Face Secrets!")
st.stop()
self.client = Groq(api_key=api_key)
# 2. Wake up tools
self.predictor = MindGuardPredictor()
self.retriever = MindGuardRetriever()
self.audio_processor = MindGuardAudioProcessor()
self.system_prompt = """
You are Safeguard AI, a highly empathetic, clinical-grade mental health AI.
Your goal is to de-escalate emotional distress and provide actionable coping strategies.
STRICT RULES:
1. NEVER hallucinate medical advice. ONLY use the 'Clinical Strategy' provided in the prompt.
2. Keep your response conversational, warm, and easy to read (use short paragraphs).
3. Do not sound like a robot reading a textbook. Weave the clinical strategy naturally into your empathy.
4. If the Risk Level is 'High', prioritize grounding the user immediately.
"""
def generate_response(self, user_input, chat_history_list):
"""Pipeline: Predict -> Retrieve -> Remember -> Generate."""
# 1. Core Model Prediction
prediction = self.predictor.predict(user_input)
emotion = prediction['emotion']
risk = prediction['risk_level']
# 2. RAG Retrieval
strategy = self.retriever.get_coping_strategy(
user_query=user_input,
emotion_filter=emotion
)
# 3. Extract recent history directly from Streamlit session state
history_text = "No previous conversation."
if chat_history_list:
recent = chat_history_list[-4:] # Grab last two exchanges
history_text = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in recent])
# 4. Prompt Engineering
augmented_prompt = f"""
--- RECENT CONVERSATION HISTORY ---
{history_text}
--- CURRENT SITUATION ---
User's New Message: "{user_input}"
AI Core Diagnosis: {emotion}
Assessed Risk Level: {risk}
Required Clinical Strategy to Teach the User:
{strategy}
Draft your response to the user's new message now:
"""
# 5. Groq Generation
chat_completion = self.client.chat.completions.create(
messages=[
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": augmented_prompt}
],
model="llama-3.3-70b-versatile",
temperature=0.3,
)
final_response = chat_completion.choices[0].message.content
return final_response, emotion, risk
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# RESOURCE CACHING
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
@st.cache_resource
def get_safeguard_bot():
return SafeguardChatbot()
@st.cache_resource
def get_shap_explainer():
return MindGuardSHAPExplainer()
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# UI HELPERS (Badges & Formatting)
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
RISK_COLORS = {
"High": ("#ff4b4b", "white"), "Medium": ("#ffa500", "white"), "Low": ("#21c354", "white")
}
EMOTION_COLORS = {
"Suicidal": "#ff4b4b", "Depression": "#e05260", "Anxiety": "#e07052", "Bipolar": "#c0392b",
"Stress": "#e67e22", "Personality disorder": "#9b59b6", "joy": "#21c354", "love": "#2ecc71",
"gratitude": "#27ae60", "optimism": "#16a085", "Normal": "#3498db", "neutral": "#3498db",
"sadness": "#e08000", "fear": "#e74c3c", "anger": "#c0392b"
}
def _emotion_badge(emotion: str) -> str:
color = EMOTION_COLORS.get(emotion, "#555555")
return f'<span style="background:{color};color:white;padding:3px 10px;border-radius:12px;font-size:13px;font-weight:600;">π§ {emotion}</span>'
def _risk_badge(risk: str) -> str:
bg, fg = RISK_COLORS.get(risk, ("#888888", "white"))
icon = {"High": "π¨", "Medium": "β οΈ", "Low": "β
"}.get(risk, "β’")
return f'<span style="background:{bg};color:{fg};padding:3px 10px;border-radius:12px;font-size:13px;font-weight:600;">{icon} Risk: {risk}</span>'
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# PAGE 1: CHAT COMPANION
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def render_chat():
st.title("π§ Mindguard AI Companion")
st.markdown("Your clinical-grade, empathetic AI. Type a message or upload a voice note.")
bot = get_safeguard_bot()
if "messages" not in st.session_state:
st.session_state.messages = []
# SIDEBAR: Audio & Controls
with st.sidebar:
st.header("ποΈ Voice Input")
audio_value = st.audio_input("Record a voice note")
if audio_value:
current_audio_size = len(audio_value.getvalue())
if ("last_audio_size" not in st.session_state or st.session_state.last_audio_size != current_audio_size):
st.session_state.last_audio_size = current_audio_size
os.makedirs(os.path.join(project_root, "data", "raw"), exist_ok=True)
temp_path = os.path.join(project_root, "data", "raw", "temp_streamlit.wav")
with open(temp_path, "wb") as f:
f.write(audio_value.getbuffer())
with st.spinner("Transcribing via Whisperβ¦"):
transcribed_text = bot.audio_processor.transcribe(temp_path)
# --- DEFERRED EXECUTION FLAGS ---
st.session_state.latest_analyzable_text = transcribed_text
st.session_state.shap_is_stale = True
response, emotion, risk = bot.generate_response(transcribed_text, st.session_state.messages)
st.session_state.messages.append({"role": "user", "content": f"π€ **Voice Note:** *{transcribed_text}*"})
st.session_state.messages.append({
"role": "assistant", "content": response, "emotion": emotion, "risk": risk
})
st.divider()
if st.button("ποΈ Clear Chat"):
st.session_state.messages = []
st.session_state.pop("last_audio_size", None)
st.session_state.pop("latest_analyzable_text", None)
st.session_state.pop("shap_is_stale", None)
st.rerun()
# MAIN WINDOW: Chat History
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
if msg["role"] == "assistant" and "emotion" in msg:
col1, col2 = st.columns([1, 1])
with col1: st.markdown(_emotion_badge(msg["emotion"]), unsafe_allow_html=True)
with col2: st.markdown(_risk_badge(msg["risk"]), unsafe_allow_html=True)
# MAIN WINDOW: Text Input
if prompt := st.chat_input("How are you feeling right now?"):
st.session_state.messages.append({"role": "user", "content": prompt})
# --- DEFERRED EXECUTION FLAGS ---
st.session_state.latest_analyzable_text = prompt
st.session_state.shap_is_stale = True
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
with st.spinner("Diagnosing and retrieving clinical strategyβ¦"):
response, emotion, risk = bot.generate_response(prompt, st.session_state.messages)
st.markdown(response)
col1, col2 = st.columns([1, 1])
with col1: st.markdown(_emotion_badge(emotion), unsafe_allow_html=True)
with col2: st.markdown(_risk_badge(risk), unsafe_allow_html=True)
st.session_state.messages.append({
"role": "assistant", "content": response, "emotion": emotion, "risk": risk
})
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# PAGE 2: CLINICAL DASHBOARD
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def render_dashboard():
st.title("π Clinical Overview")
st.markdown("Real-time emotional tracking and XAI reports for the current session.")
tab1, tab2 = st.tabs(["π Analytics", "π¬ XAI Report"])
SHAP_HTML_PATH = os.path.join(project_root, "artifacts", "shap_report.html")
with tab1:
history_data = []
for msg in st.session_state.get("messages", []):
if msg["role"] == "assistant" and "emotion" in msg:
history_data.append({
"diagnosed_emotion": msg["emotion"],
"risk_level": msg["risk"],
"timestamp": pd.Timestamp.now().strftime("%H:%M:%S")
})
if not history_data:
st.info("No data yet. Start chatting to generate live analytics.")
return
df = pd.DataFrame(history_data)
total = len(df)
high_risk = len(df[df["risk_level"] == "High"])
top_emotion = df["diagnosed_emotion"].mode()[0] if not df.empty else "N/A"
col1, col2, col3 = st.columns(3)
col1.metric("Interactions This Session", total)
col2.metric("High Risk Flags", high_risk, delta_color="inverse")
col3.metric("Primary Emotion", top_emotion)
st.divider()
st.subheader("Risk Level Summary")
risk_cols = st.columns(3)
for i, level in enumerate(["High", "Medium", "Low"]):
count = len(df[df["risk_level"] == level])
icons = {"High": "π¨", "Medium": "β οΈ", "Low": "β
"}
risk_cols[i].metric(f"{icons[level]} {level}", count)
st.divider()
chart_col1, chart_col2 = st.columns(2)
with chart_col1:
st.subheader("Emotion Frequency")
st.bar_chart(df["diagnosed_emotion"].value_counts())
with chart_col2:
st.subheader("Risk Level Distribution")
st.bar_chart(df["risk_level"].value_counts())
with tab2:
st.subheader("π¬ Last SHAP Word-Level Explanation")
# --- THE FIX: Lazy Loading Execution ---
if st.session_state.get("shap_is_stale", False) and "latest_analyzable_text" in st.session_state:
with st.spinner("Calculating XAI feature importance for the latest message... (This takes a moment)"):
shap_ex = get_shap_explainer()
shap_ex.generate_visual_report(st.session_state.latest_analyzable_text)
st.session_state.shap_is_stale = False
# Render the HTML if it exists
if os.path.exists(SHAP_HTML_PATH):
with open(SHAP_HTML_PATH, "r", encoding="utf-8") as f:
shap_html = f.read()
safe_html = f'<div style="background-color: white; padding: 20px; border-radius: 10px;">{shap_html}</div>'
components.html(safe_html, height=500, scrolling=True)
st.caption(f"Report path: `{SHAP_HTML_PATH}`")
else:
st.info("Send a message in the Chat Companion tab to generate the first XAI report.")
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# MAIN ROUTER
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
st.sidebar.title("π‘οΈ Mindguard AI")
st.sidebar.markdown("Welcome to the control panel.")
page = st.sidebar.radio("Navigation", ["π¬ Chat Companion", "π Clinical Dashboard"])
if page == "π¬ Chat Companion":
render_chat()
elif page == "π Clinical Dashboard":
render_dashboard() |