reframe / distortion_parser.py
Venkatesh Rajagopal
REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
4ae4ae8
Raw
History Blame Contribute Delete
2.58 kB
"""Cognitive distortion parser — detects distortion labels in model output."""
from __future__ import annotations
# Trigger phrases that indicate the model is naming a distortion.
#
# Keys are the distortion labels that detect_distortions() returns; it walks
# this dict in insertion order, so reordering keys changes the order of its
# results.
#
# Each phrase is matched as a plain, case-insensitive substring of the model
# output -- there are no word boundaries. Things to keep in mind when editing:
#   * Truncated stems such as "catastrophiz", "overgeneraliz", "personaliz",
#     "mind-read" and "fortune-tell" appear intended to cover several
#     inflections at once ("-ize", "-ing", "-er", ...).
#   * Very common words ("always", "never", "predicting", "doesn't count")
#     also fire on ordinary prose, e.g. "never" matches "nevertheless".
#   * Some entries are subsumed by a shorter entry in the same list (any text
#     containing "the word 'always'" already contains "always"); they are
#     redundant but harmless.
#   * Phrases containing apostrophes/quotes are written with ASCII ' and "
#     only.
DISTORTION_TRIGGERS: dict[str, list[str]] = {
    "catastrophizing": [
        "catastrophiz", "worst case", "jumping to the worst",
        "imagining the worst", "end of the world",
    ],
    # "always"/"never" on their own already cover the quoted and
    # "the word ..." variants below.
    "overgeneralization": [
        "overgeneraliz", "always", "never", "every time",
        "\"always\"", "\"never\"", "the word 'always'",
        "the word 'never'",
    ],
    "all-or-nothing thinking": [
        "all-or-nothing", "black and white", "black-and-white",
        "either/or", "all or nothing", "binary thinking",
        "no middle ground",
    ],
    "mind-reading": [
        "mind-read", "mind read", "assuming what they think",
        "guessing their", "assuming they",
        "you're reading their mind",
    ],
    # "predicting" alone already covers "you're predicting".
    "fortune-telling": [
        "fortune-tell", "fortune tell", "predicting",
        "jumping ahead", "you're predicting",
        "crystal ball",
    ],
    "should-statements": [
        "should statement", "shoulding yourself", "'should'",
        "\"should\"", "must/should", "the word 'should'",
    ],
    "emotional reasoning": [
        "emotional reasoning", "feeling it doesn't make it",
        "feeling something doesn't make it true",
        "just because you feel",
    ],
    # NOTE: "you're not a" is a prefix match, so it also fires on
    # "you're not alone" / "you're not at fault".
    "labeling": [
        "labeling yourself", "putting a label",
        "you're not a", "that's a label",
        "calling yourself",
    ],
    "personalization": [
        "personaliz", "taking responsibility for",
        "blaming yourself for", "not everything is about",
        "not your fault",
    ],
    "mental filter": [
        "mental filter", "filtering out", "only seeing the negative",
        "ignoring the positive", "focusing only on",
    ],
    "disqualifying the positive": [
        "disqualifying", "dismissing the positive",
        "doesn't count", "that was just luck",
    ],
}
# Typographic quotes that model output (or a chat UI rendering it) may use in
# place of the ASCII ' and " that appear in the trigger phrases, e.g. the
# "you\u2019re" form of "you're".
_QUOTE_NORMALIZATION = str.maketrans({
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark / typographic apostrophe
    "\u201c": '"',  # left double quotation mark
    "\u201d": '"',  # right double quotation mark
})


def _normalize(s: str) -> str:
    """Lower-case *s* and fold typographic quotes to their ASCII forms."""
    return s.lower().translate(_QUOTE_NORMALIZATION)


def detect_distortions(
    text: str | None,
    triggers: dict[str, list[str]] | None = None,
) -> list[str]:
    """Detect cognitive distortions mentioned in model output.

    Matching is a case-insensitive plain-substring search: a distortion is
    reported when any one of its trigger phrases occurs anywhere in *text*.
    There are no word boundaries, so broad triggers can fire on unrelated
    wording (with the default table, "never" fires on "nevertheless").
    Typographic quotes (U+2018, U+2019, U+201C, U+201D) in both the text and
    the trigger phrases are folded to ASCII ' and " before comparing, so
    "you\u2019re predicting" still matches the trigger "you're predicting".

    Args:
        text: The model's response text. ``None`` or an empty string yields
            no detections.
        triggers: Optional mapping of distortion name -> trigger phrases to
            use instead of the module-level ``DISTORTION_TRIGGERS``.

    Returns:
        Names of the detected distortions, each at most once, in the order
        their keys appear in the trigger mapping (may be empty).
    """
    if not text:
        return []

    # Resolved at call time (not as a default argument) so later changes to
    # the module-level table are picked up.
    table = DISTORTION_TRIGGERS if triggers is None else triggers
    haystack = _normalize(text)

    # Mapping keys are unique, so each distortion is emitted at most once;
    # any() stops at the first matching phrase.
    return [
        distortion
        for distortion, phrases in table.items()
        if any(_normalize(phrase) in haystack for phrase in phrases)
    ]