""" MindWatch — Explainability Module SHAP-based word importance and attention visualization. """ import numpy as np from typing import List, Dict, Tuple from utils.preprocessing import preprocess_text, tokenize # Distress-indicative lexicon (research-backed) DISTRESS_LEXICON = { "depression": { "hopeless", "worthless", "empty", "numb", "alone", "sad", "crying", "tired", "exhausted", "meaningless", "pointless", "nothing", "dark", "dead", "dying", "hate", "miserable", "suffering", "broken", "lost", "heavy", "trapped", "useless", "failure", "burden", "guilty", }, "anxiety": { "worried", "nervous", "panic", "afraid", "scared", "terrified", "overthinking", "racing", "shaking", "trembling", "catastrophe", "dread", "tense", "restless", "obsessing", "paranoid", "phobia", "fear", "uneasy", "apprehensive", "overwhelmed", }, "stress": { "stressed", "overwhelmed", "pressure", "deadline", "burnout", "exhausting", "frustrating", "overworked", "struggling", "chaos", "demanding", "impossible", "swamped", "drowning", "cracking", "snapped", "breaking", "frantic", "hectic", }, } # Intensity modifiers INTENSIFIERS = {"very", "so", "extremely", "completely", "totally", "absolutely", "utterly"} NEGATORS = {"not", "no", "never", "nothing", "nobody", "none", "cannot", "hardly", "barely"} def compute_word_importance( text: str, predicted_label: str, probabilities: Dict[str, float], ) -> List[Tuple[str, float]]: """ Compute word-level importance scores using lexicon matching + TF-based scoring. This is a lightweight alternative to full SHAP for the demo. Returns: List of (word, importance_score) tuples, sorted by importance. """ clean = preprocess_text(text) words = tokenize(clean) if not words: return [] label_confidence = probabilities.get(predicted_label, 0.5) target_lexicon = set() for category in DISTRESS_LEXICON.values(): target_lexicon.update(category) primary_lexicon = DISTRESS_LEXICON.get(predicted_label, set()) scores = [] for i, word in enumerate(words): score = 0.0 # Primary category match (strongest signal) if word in primary_lexicon: score += 0.8 # Any distress lexicon match elif word in target_lexicon: score += 0.4 # Negation / intensifier context if word in NEGATORS: score += 0.5 if word in INTENSIFIERS: score += 0.3 # First-person pronouns (self-focus) if word in {"i", "me", "my", "myself"}: score += 0.15 # Absolutist language if word in {"always", "never", "everything", "nothing", "completely"}: score += 0.35 # Context: intensifier before a distress word if i > 0 and words[i - 1] in INTENSIFIERS and word in target_lexicon: score += 0.3 # Scale by prediction confidence score *= label_confidence scores.append((word, round(score, 3))) # Normalize max_score = max((s for _, s in scores), default=1.0) if max_score > 0: scores = [(w, round(s / max_score, 3)) for w, s in scores] # Sort by importance scores.sort(key=lambda x: x[1], reverse=True) return scores def get_important_words( text: str, predicted_label: str, probabilities: Dict[str, float], top_k: int = 8, ) -> List[Dict]: """ Get top-k important words with their scores and categories. """ word_scores = compute_word_importance(text, predicted_label, probabilities) results = [] seen = set() for word, score in word_scores: if word in seen or score <= 0 or len(word) < 2: continue seen.add(word) category = "neutral" for cat, lexicon in DISTRESS_LEXICON.items(): if word in lexicon: category = cat break if word in NEGATORS: category = "negation" if word in INTENSIFIERS: category = "intensifier" results.append({ "word": word, "score": score, "category": category, }) if len(results) >= top_k: break return results def format_explanation( text: str, predicted_label: str, probabilities: Dict[str, float], ) -> str: """ Generate a human-readable explanation of the prediction. """ important = get_important_words(text, predicted_label, probabilities) confidence = probabilities.get(predicted_label, 0.0) if not important: return f"Prediction: {predicted_label.title()} (confidence: {confidence:.1%})\nNo strong distress indicators found in the text." lines = [ f"Prediction: {predicted_label.title()} (confidence: {confidence:.1%})", "", "Key indicators:", ] for item in important: bar = "█" * int(item["score"] * 10) lines.append(f" • \"{item['word']}\" [{item['category']}] {bar} {item['score']:.2f}") return "\n".join(lines) if __name__ == "__main__": test_text = "I feel completely exhausted and nothing seems to work anymore." probs = {"depression": 0.72, "anxiety": 0.12, "stress": 0.11, "normal": 0.05} print(format_explanation(test_text, "depression", probs)) print() print("Important words:") for w in get_important_words(test_text, "depression", probs): print(f" {w}")