# mindwatch/models/explainability.py
# (Initial MindWatch deployment — commit 4c423a1, author: priyadip)
"""
MindWatch — Explainability Module
SHAP-based word importance and attention visualization.
"""
import numpy as np
from typing import List, Dict, Tuple
from utils.preprocessing import preprocess_text, tokenize
# Distress-indicative lexicon (research-backed).
# Maps each predicted label to a set of lowercase single-word cues;
# compute_word_importance matches tokens against these sets.
DISTRESS_LEXICON = {
    "depression": {
        "hopeless", "worthless", "empty", "numb", "alone", "sad", "crying",
        "tired", "exhausted", "meaningless", "pointless", "nothing", "dark",
        "dead", "dying", "hate", "miserable", "suffering", "broken", "lost",
        "heavy", "trapped", "useless", "failure", "burden", "guilty",
    },
    "anxiety": {
        "worried", "nervous", "panic", "afraid", "scared", "terrified",
        "overthinking", "racing", "shaking", "trembling", "catastrophe",
        "dread", "tense", "restless", "obsessing", "paranoid", "phobia",
        "fear", "uneasy", "apprehensive", "overwhelmed",
    },
    "stress": {
        "stressed", "overwhelmed", "pressure", "deadline", "burnout",
        "exhausting", "frustrating", "overworked", "struggling", "chaos",
        "demanding", "impossible", "swamped", "drowning", "cracking",
        "snapped", "breaking", "frantic", "hectic",
    },
}
# Intensity modifiers: intensifiers amplify a following distress word;
# negators are scored as signals in their own right (see scoring below).
# NOTE: "nothing" appears in both the depression lexicon and NEGATORS,
# so it accumulates weight from multiple rules — this is by design.
INTENSIFIERS = {"very", "so", "extremely", "completely", "totally", "absolutely", "utterly"}
NEGATORS = {"not", "no", "never", "nothing", "nobody", "none", "cannot", "hardly", "barely"}
def compute_word_importance(
    text: str,
    predicted_label: str,
    probabilities: Dict[str, float],
) -> List[Tuple[str, float]]:
    """
    Score each word's contribution to the predicted label.

    Combines lexicon matching with simple contextual heuristics
    (negation, intensification, self-focus, absolutist language) as a
    lightweight stand-in for full SHAP attribution in the demo.

    Returns:
        List of (word, importance) tuples sorted by descending score;
        scores are normalized so the strongest word is 1.0.
    """
    tokens = tokenize(preprocess_text(text))
    if not tokens:
        return []

    # Confidence of the predicted class; 0.5 if the label is missing.
    confidence = probabilities.get(predicted_label, 0.5)

    # Union of every category's cue words, plus the predicted label's own set.
    all_distress: set = set()
    for category_words in DISTRESS_LEXICON.values():
        all_distress |= category_words
    label_words = DISTRESS_LEXICON.get(predicted_label, set())

    self_focus = {"i", "me", "my", "myself"}
    absolutist = {"always", "never", "everything", "nothing", "completely"}

    raw: List[Tuple[str, float]] = []
    for idx, token in enumerate(tokens):
        weight = 0.0
        # Match against the predicted category first (strongest signal),
        # otherwise against any distress category.
        if token in label_words:
            weight += 0.8
        elif token in all_distress:
            weight += 0.4
        # Negation / intensifier cues score on their own.
        if token in NEGATORS:
            weight += 0.5
        if token in INTENSIFIERS:
            weight += 0.3
        # First-person pronouns indicate self-focus.
        if token in self_focus:
            weight += 0.15
        # Absolutist language ("always", "never", ...).
        if token in absolutist:
            weight += 0.35
        # Bonus when an intensifier directly precedes a distress word.
        if idx > 0 and tokens[idx - 1] in INTENSIFIERS and token in all_distress:
            weight += 0.3
        # Scale by prediction confidence before recording.
        raw.append((token, round(weight * confidence, 3)))

    # Normalize so the top word scores 1.0 (skip if everything is zero).
    peak = max((value for _, value in raw), default=1.0)
    if peak > 0:
        raw = [(token, round(value / peak, 3)) for token, value in raw]

    raw.sort(key=lambda pair: pair[1], reverse=True)
    return raw
def get_important_words(
    text: str,
    predicted_label: str,
    probabilities: Dict[str, float],
    top_k: int = 8,
) -> List[Dict]:
    """
    Return up to ``top_k`` distinct important words for a prediction.

    Each result dict holds the word, its normalized importance score,
    and a category tag: the first matching distress lexicon category,
    overridden by "negation"/"intensifier" when applicable, else
    "neutral". Zero-score and single-character tokens are skipped.
    """
    ranked = compute_word_importance(text, predicted_label, probabilities)
    picked: List[Dict] = []
    used = set()
    for word, score in ranked:
        # Skip duplicates, non-signals, and one-letter tokens.
        if word in used or score <= 0 or len(word) < 2:
            continue
        used.add(word)
        # First lexicon category containing the word, else "neutral".
        category = next(
            (cat for cat, lexicon in DISTRESS_LEXICON.items() if word in lexicon),
            "neutral",
        )
        # Modifier roles take precedence over lexicon categories.
        if word in NEGATORS:
            category = "negation"
        if word in INTENSIFIERS:
            category = "intensifier"
        picked.append({
            "word": word,
            "score": score,
            "category": category,
        })
        if len(picked) >= top_k:
            break
    return picked
def format_explanation(
    text: str,
    predicted_label: str,
    probabilities: Dict[str, float],
) -> str:
    """
    Build a human-readable, multi-line explanation of a prediction:
    the label with its confidence, followed by the top indicator words
    rendered with bar-chart glyphs proportional to their scores.
    """
    indicators = get_important_words(text, predicted_label, probabilities)
    confidence = probabilities.get(predicted_label, 0.0)
    header = f"Prediction: {predicted_label.title()} (confidence: {confidence:.1%})"
    if not indicators:
        return header + "\nNo strong distress indicators found in the text."
    output = [header, "", "Key indicators:"]
    for entry in indicators:
        # One █ per 0.1 of normalized score.
        bar = "█" * int(entry["score"] * 10)
        output.append(f" • \"{entry['word']}\" [{entry['category']}] {bar} {entry['score']:.2f}")
    return "\n".join(output)
if __name__ == "__main__":
test_text = "I feel completely exhausted and nothing seems to work anymore."
probs = {"depression": 0.72, "anxiety": 0.12, "stress": 0.11, "normal": 0.05}
print(format_explanation(test_text, "depression", probs))
print()
print("Important words:")
for w in get_important_words(test_text, "depression", probs):
print(f" {w}")