---
license: apache-2.0
language:
- en
pipeline_tag: text-classification
tags:
- medical
- code
- mental_health
- classifier
library_name: transformers
---

# CBT Cognitive Distortion Classifier

A fine-tuned DistilBERT model for detecting cognitive distortions in text, based on Cognitive Behavioral Therapy (CBT) principles.

## Model Description

This model identifies 5 common cognitive distortions in conversational text:

- **Overgeneralization**: Using "always", "never", "everyone", "nobody"
- **Catastrophizing**: Using "terrible", "awful", "worst", "disaster"
- **Black and White Thinking**: All-or-nothing, either/or patterns
- **Self-Blame**: "My fault", "blame myself", "guilty"
- **Mind Reading**: "They think", "must think", "probably think"

## Model Details

- **Base Model**: `distilbert-base-uncased`
- **Task**: Multi-label classification
- **Number of Labels**: 5
- **Training Data**: 231 samples from mental health conversational data
- **Training Split**: 196 train / 35 test
- **Framework**: HuggingFace Transformers

## Training Performance

| Epoch | Training Loss | Validation Loss |
|-------|--------------|-----------------|
| 1     | 0.1200       | 0.0857          |
| 2     | 0.0322       | 0.0258          |
| 3     | 0.0165       | 0.0129          |
| 4     | 0.0335       | 0.0084          |
| 5     | 0.0079       | 0.0067          |
| 6     | 0.0066       | 0.0056          |
| 7     | 0.0311       | 0.0048          |
| 8     | 0.0523       | 0.0045          |
| 9     | 0.0051       | 0.0044          |
| 10    | 0.0278       | 0.0043          |

**Final Validation Loss**: 0.0043

## Training Configuration

- **Epochs**: 10
- **Batch Size**: 8
- **Evaluation Strategy**: Per epoch
- **Optimizer**: AdamW (default)
- **Max Sequence Length**: 128
- **Device**: GPU (Tesla T4)
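
As a rough sanity check on this configuration, the number of optimizer steps follows from the 196-sample training split and the batch size of 8. The sketch below assumes a single device, no gradient accumulation, and that the last partial batch is kept; the card states none of these.

```python
import math

# Values taken from this model card.
train_samples = 196
batch_size = 8
epochs = 10

# Assumption: single device, no gradient accumulation, last partial batch kept.
steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * epochs

print(f"steps per epoch: {steps_per_epoch}")
print(f"total optimizer steps: {total_steps}")
```

With so few steps, each epoch sees the whole training set in a couple of dozen updates, which is consistent with the noisy training loss in the table above.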

## Usage

### Loading the Model

```python
import json

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model and tokenizer from the local model directory
model = AutoModelForSequenceClassification.from_pretrained("./cbt_model_final")
tokenizer = AutoTokenizer.from_pretrained("./cbt_model_final")

# Load label mappings (index -> distortion name)
with open("./cbt_model_final/label_config.json", "r") as f:
    label_config = json.load(f)

id2label = label_config["id2label"]
```
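
The loading code reads `id2label` from `label_config.json`, and the prediction code later indexes it with string keys (`id2label[str(idx)]`). A minimal sketch of a file with that shape follows. The label order and the exact label strings are assumptions for illustration, so check the shipped file for the real mapping.

```python
import json

# Hypothetical contents of label_config.json; the index order is an assumption.
example_label_config = {
    "id2label": {
        "0": "overgeneralization",
        "1": "catastrophizing",
        "2": "black_and_white_thinking",
        "3": "self_blame",
        "4": "mind_reading",
    }
}

# JSON object keys are always strings, which is why lookups use str(idx).
round_tripped = json.loads(json.dumps(example_label_config))
id2label_example = round_tripped["id2label"]

print(id2label_example[str(0)])
```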

### Making Predictions

```python
def predict_distortions(text, threshold=0.5):
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    
    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
        # Drop only the batch dimension so the result is always one score per label
        probabilities = torch.sigmoid(outputs.logits).squeeze(0)
    
    # Extract distortions above threshold
    detected = []
    for idx, prob in enumerate(probabilities):
        if prob > threshold:
            label = id2label[str(idx)]
            detected.append({
                "distortion": label,
                "confidence": f"{prob.item():.2%}"
            })
    
    return detected

# Example usage
text = "I always mess everything up. This is a disaster!"
distortions = predict_distortions(text)

for d in distortions:
    print(f"{d['distortion']}: {d['confidence']}")
```
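
Because this is a multi-label task, each logit passes through its own sigmoid and is compared with the threshold independently, so a text can receive zero, one, or several labels. The sketch below shows that step in plain Python with made-up logits. The label names and their order are assumptions, not values read from the model.

```python
import math

def sigmoid(x):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def labels_above_threshold(logits, labels, threshold=0.5):
    """Return (label, probability) pairs whose probability exceeds the threshold."""
    probs = [sigmoid(z) for z in logits]
    return [(label, p) for label, p in zip(labels, probs) if p > threshold]

# Hypothetical label order and logits, for illustration only.
demo_labels = [
    "overgeneralization",
    "catastrophizing",
    "black_and_white_thinking",
    "self_blame",
    "mind_reading",
]
demo_logits = [1.0, 0.8, -2.0, -3.0, -1.5]

selected = labels_above_threshold(demo_logits, demo_labels)
for label, prob in selected:
    print(f"{label}: {prob:.2%}")
```

With a threshold of 0.5, a label is selected exactly when its logit is positive, since `sigmoid(0) == 0.5`.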

### Example Output

```
Input: "I always mess everything up. This is a disaster!"

Detected distortions:
  overgeneralization: 73.45%
  catastrophizing: 68.92%
```

## Model Limitations

⚠️ **Important Considerations**:

- **Small Training Dataset**: Only 231 samples - model may not generalize well to all contexts
- **Rule-Based Labels**: Training labels were created using keyword matching, not expert annotations
- **Prototype Quality**: This is a proof-of-concept model, not production-ready
- **Low Confidence Scores**: Average predictions are low (reported as 0.25-0.73%), so the model is conservative and some distortions may fall below the default 0.5 threshold
- **Limited Context**: Only trained on short conversational patterns
- **No Clinical Validation**: Not validated by mental health professionals
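
The rule-based labeling mentioned above can be pictured as simple keyword matching. The sketch below uses the example keywords from the Model Description for four of the labels. It illustrates the general approach only and is not the labeling script actually used; the label names, keyword lists, and whole-word matching are assumptions. Black and White Thinking is omitted because the card lists no specific keywords for it.

```python
import re

# Keywords copied from the Model Description; the label names are assumptions.
KEYWORDS = {
    "overgeneralization": ["always", "never", "everyone", "nobody"],
    "catastrophizing": ["terrible", "awful", "worst", "disaster"],
    "self_blame": ["my fault", "blame myself", "guilty"],
    "mind_reading": ["they think", "must think", "probably think"],
}

def keyword_labels(text):
    """Return labels whose keywords appear as whole words or phrases in text."""
    lowered = text.lower()
    found = []
    for label, phrases in KEYWORDS.items():
        # \b anchors keep "never" from matching inside a longer word.
        if any(re.search(r"\b" + re.escape(p) + r"\b", lowered) for p in phrases):
            found.append(label)
    return found

print(keyword_labels("I always mess everything up. This is a disaster!"))
```

A matcher like this also fires on neutral sentences such as "I never said that", which is one reason keyword-derived labels are noisy.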

## Recommendations for Improvement

1. **Expand Dataset**: Collect more diverse, expert-annotated examples
2. **Better Labeling**: Use clinical experts to label cognitive distortions
3. **Data Augmentation**: Generate synthetic examples for underrepresented patterns
4. **Hyperparameter Tuning**: Experiment with learning rates, batch sizes, epochs
5. **Evaluation Metrics**: Add precision, recall, F1-score tracking
6. **Class Balancing**: Address imbalanced distribution of distortion types
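
For the evaluation metrics recommendation, per-label precision, recall, and F1 can be computed directly from binary indicator rows. This is a minimal plain-Python sketch on made-up toy data, not results from this model; the function name `per_label_prf` is hypothetical.

```python
def per_label_prf(y_true, y_pred, num_labels):
    """Compute precision, recall, and F1 per label from 0/1 indicator rows."""
    results = []
    for j in range(num_labels):
        pairs = [(t[j], p[j]) for t, p in zip(y_true, y_pred)]
        tp = sum(1 for t, p in pairs if t == 1 and p == 1)
        fp = sum(1 for t, p in pairs if t == 0 and p == 1)
        fn = sum(1 for t, p in pairs if t == 1 and p == 0)
        # Define each metric as 0.0 when its denominator is zero.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results.append({"precision": precision, "recall": recall, "f1": f1})
    return results

# Toy data (made up): 4 samples, 2 labels.
toy_true = [[1, 0], [1, 1], [0, 0], [1, 0]]
toy_pred = [[1, 0], [0, 1], [0, 1], [1, 0]]

metrics = per_label_prf(toy_true, toy_pred, num_labels=2)
for j, m in enumerate(metrics):
    print(f"label {j}: P={m['precision']:.2f} R={m['recall']:.2f} F1={m['f1']:.2f}")
```

Reporting these per label, rather than only a single loss value, also makes class imbalance visible.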

## Files Included

```
cbt_model_final/
├── config.json              # Model configuration
├── model.safetensors        # Model weights
├── tokenizer_config.json    # Tokenizer configuration
├── vocab.txt                # Vocabulary
├── special_tokens_map.json  # Special tokens
├── tokenizer.json           # Tokenizer data
└── label_config.json        # Label mappings
```

## License

This model is based on DistilBERT and inherits its Apache 2.0 license.

## Citation

```
Base Model: DistilBERT
Original Paper: Sanh et al. (2019) - DistilBERT, a distilled version of BERT
Fine-tuning: Custom CBT distortion detection
```

## Disclaimer

⚠️ **This model is for educational and research purposes only.** It should not be used as a substitute for professional mental health diagnosis or treatment. Always consult qualified mental health professionals for clinical applications.

---

**Created**: December 2025  
**Framework**: HuggingFace Transformers  
**Hardware**: Kaggle GPU (Tesla T4)