---
license: apache-2.0
language:
- en
pipeline_tag: text-classification
tags:
- medical
- code
- mental_health
- classifier
library_name: transformers
---
# CBT Cognitive Distortion Classifier
A fine-tuned DistilBERT model for detecting cognitive distortions in text, based on Cognitive Behavioral Therapy (CBT) principles.
## Model Description
This model identifies 5 common cognitive distortions in conversational text:
- **Overgeneralization**: Using "always", "never", "everyone", "nobody"
- **Catastrophizing**: Using "terrible", "awful", "worst", "disaster"
- **Black and White Thinking**: All-or-nothing, either/or patterns
- **Self-Blame**: "My fault", "blame myself", "guilty"
- **Mind Reading**: "They think", "must think", "probably think"
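The limitations section below notes that training labels were produced by keyword matching rather than expert annotation. A minimal sketch of such a rule-based labeler, using keyword lists inferred from the patterns above (the exact lists used in training are not published, so these are illustrative assumptions):

```python
# Hypothetical sketch of the keyword-based labeling described in this card.
# Keyword lists are illustrative, based on the patterns listed above; the
# actual lists used to build the training set are not published.
DISTORTION_KEYWORDS = {
    "overgeneralization": ["always", "never", "everyone", "nobody"],
    "catastrophizing": ["terrible", "awful", "worst", "disaster"],
    "black_and_white": ["all or nothing", "either", "completely"],
    "self_blame": ["my fault", "blame myself", "guilty"],
    "mind_reading": ["they think", "must think", "probably think"],
}

def rule_based_labels(text):
    """Return a multi-hot vector: 1 if any keyword for that label appears."""
    lowered = text.lower()
    return [
        int(any(kw in lowered for kw in kws))
        for kws in DISTORTION_KEYWORDS.values()
    ]
```

This kind of labeling is cheap but noisy, which is why the card recommends expert annotation below.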
## Model Details
- **Base Model**: `distilbert-base-uncased`
- **Task**: Multi-label classification
- **Number of Labels**: 5
- **Training Data**: 231 samples from mental health conversational data
- **Training Split**: 196 train / 35 test
- **Framework**: HuggingFace Transformers
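Because the task is multi-label, the classification head applies an independent sigmoid per label instead of a softmax, so several distortions can fire on the same input. A short illustration with made-up logit values:

```python
import torch

# Multi-label classification: each of the 5 labels gets an independent
# probability via sigmoid, so several distortions can be detected at once
# (unlike softmax, where probabilities compete and must sum to 1).
logits = torch.tensor([[2.0, 1.0, -3.0, -2.0, 0.5]])  # illustrative values
probs = torch.sigmoid(logits)
prob_sum = probs.sum().item()                # need not equal 1
detected = (probs > 0.5).squeeze().tolist()  # per-label threshold at 0.5
```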
## Training Performance
| Epoch | Training Loss | Validation Loss |
|-------|--------------|-----------------|
| 1 | 0.1200 | 0.0857 |
| 2 | 0.0322 | 0.0258 |
| 3 | 0.0165 | 0.0129 |
| 4 | 0.0335 | 0.0084 |
| 5 | 0.0079 | 0.0067 |
| 6 | 0.0066 | 0.0056 |
| 7 | 0.0311 | 0.0048 |
| 8 | 0.0523 | 0.0045 |
| 9 | 0.0051 | 0.0044 |
| 10 | 0.0278 | 0.0043 |
**Final Validation Loss**: 0.0043
## Training Configuration
- **Epochs**: 10
- **Batch Size**: 8
- **Evaluation Strategy**: per epoch
- **Optimizer**: AdamW (default)
- **Max Sequence Length**: 128
- **Device**: GPU (Tesla T4)
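The configuration above maps roughly onto HuggingFace `TrainingArguments` as follows. This is a sketch, not the actual training script; values not listed in the card (learning rate, weight decay, etc.) are left at library defaults:

```python
from transformers import TrainingArguments

# Sketch: mirrors the configuration listed above. The output path is
# hypothetical, and anything the card does not specify stays at defaults.
args = TrainingArguments(
    output_dir="cbt_model",
    num_train_epochs=10,
    per_device_train_batch_size=8,
)
# Per-epoch evaluation is enabled with eval_strategy="epoch"
# (named evaluation_strategy in older transformers releases).
```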
## Usage
### Loading the Model
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json

# Load model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained("./cbt_model_final")
tokenizer = AutoTokenizer.from_pretrained("./cbt_model_final")

# Load label mappings (label index -> distortion name)
with open("./cbt_model_final/label_config.json", "r") as f:
    label_config = json.load(f)
id2label = label_config["id2label"]
```
### Making Predictions
```python
def predict_distortions(text, threshold=0.5):
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

    # Get per-label probabilities (sigmoid, since this is multi-label)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.sigmoid(outputs.logits).squeeze()

    # Extract distortions above threshold
    detected = []
    for idx, prob in enumerate(probabilities):
        if prob > threshold:
            label = id2label[str(idx)]
            detected.append({
                "distortion": label,
                "confidence": f"{prob.item():.2%}",
            })
    return detected

# Example usage
text = "I always mess everything up. This is a disaster!"
distortions = predict_distortions(text)
for d in distortions:
    print(f"{d['distortion']}: {d['confidence']}")
```
### Example Output
```
Input: "I always mess everything up. This is a disaster!"
Detected distortions:
overgeneralization: 73.45%
catastrophizing: 68.92%
```
## Model Limitations
⚠️ **Important Considerations**:
- **Small Training Dataset**: Only 231 samples - model may not generalize well to all contexts
- **Rule-Based Labels**: Training labels were created using keyword matching, not expert annotations
- **Prototype Quality**: This is a proof-of-concept model, not production-ready
- **Low Confidence Scores**: Typical predicted probabilities fall between roughly 0.25 and 0.73, indicating the model is conservative
- **Limited Context**: Only trained on short conversational patterns
- **No Clinical Validation**: Not validated by mental health professionals
## Recommendations for Improvement
1. **Expand Dataset**: Collect more diverse, expert-annotated examples
2. **Better Labeling**: Use clinical experts to label cognitive distortions
3. **Data Augmentation**: Generate synthetic examples for underrepresented patterns
4. **Hyperparameter Tuning**: Experiment with learning rates, batch sizes, epochs
5. **Evaluation Metrics**: Add precision, recall, F1-score tracking
6. **Class Balancing**: Address imbalanced distribution of distortion types
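For point 5, multi-label precision, recall, and F1 can be tracked with scikit-learn's `precision_recall_fscore_support` on multi-hot label matrices. A minimal sketch with made-up prediction data (not real evaluation results for this model):

```python
from sklearn.metrics import precision_recall_fscore_support

# Illustrative data only: rows = samples, columns = the 5 distortion labels.
y_true = [[1, 1, 0, 0, 0],
          [0, 0, 1, 0, 0],
          [1, 0, 0, 1, 0]]
y_pred = [[1, 0, 0, 0, 0],
          [0, 0, 1, 0, 0],
          [1, 0, 0, 1, 1]]

# Micro-averaging pools true/false positives across all labels; with
# imbalanced distortion types, also check average="macro" per-label.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="micro", zero_division=0
)
```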
## Files Included
```
cbt_model_final/
├── config.json               # Model configuration
├── model.safetensors         # Model weights
├── tokenizer_config.json     # Tokenizer configuration
├── vocab.txt                 # Vocabulary
├── special_tokens_map.json   # Special tokens
├── tokenizer.json            # Tokenizer data
└── label_config.json         # Label mappings
```
## License
This model is based on DistilBERT and inherits its Apache 2.0 license.
## Citation
```
Base Model: DistilBERT
Original Paper: Sanh et al. (2019), "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter"
Fine-tuning: Custom CBT distortion detection
```
## Disclaimer
⚠️ **This model is for educational and research purposes only.** It should not be used as a substitute for professional mental health diagnosis or treatment. Always consult qualified mental health professionals for clinical applications.
---
**Created**: December 2025
**Framework**: HuggingFace Transformers
**Hardware**: Kaggle GPU (Tesla T4)