"""Cognitive distortion parser — detects distortion labels in model output.""" from __future__ import annotations # Trigger phrases that indicate the model is naming a distortion DISTORTION_TRIGGERS: dict[str, list[str]] = { "catastrophizing": [ "catastrophiz", "worst case", "jumping to the worst", "imagining the worst", "end of the world", ], "overgeneralization": [ "overgeneraliz", "always", "never", "every time", "\"always\"", "\"never\"", "the word 'always'", "the word 'never'", ], "all-or-nothing thinking": [ "all-or-nothing", "black and white", "black-and-white", "either/or", "all or nothing", "binary thinking", "no middle ground", ], "mind-reading": [ "mind-read", "mind read", "assuming what they think", "guessing their", "assuming they", "you're reading their mind", ], "fortune-telling": [ "fortune-tell", "fortune tell", "predicting", "jumping ahead", "you're predicting", "crystal ball", ], "should-statements": [ "should statement", "shoulding yourself", "'should'", "\"should\"", "must/should", "the word 'should'", ], "emotional reasoning": [ "emotional reasoning", "feeling it doesn't make it", "feeling something doesn't make it true", "just because you feel", ], "labeling": [ "labeling yourself", "putting a label", "you're not a", "that's a label", "calling yourself", ], "personalization": [ "personaliz", "taking responsibility for", "blaming yourself for", "not everything is about", "not your fault", ], "mental filter": [ "mental filter", "filtering out", "only seeing the negative", "ignoring the positive", "focusing only on", ], "disqualifying the positive": [ "disqualifying", "dismissing the positive", "doesn't count", "that was just luck", ], } def detect_distortions(text: str) -> list[str]: """Detect cognitive distortions mentioned in model output. Args: text: The model's response text. Returns: List of distortion names detected (may be empty). """ text_lower = text.lower() detected = [] for distortion, triggers in DISTORTION_TRIGGERS.items(): for trigger in triggers: if trigger.lower() in text_lower: if distortion not in detected: detected.append(distortion) break return detected