---
license: apache-2.0
language:
- en
pipeline_tag: text-classification
tags:
- medical
- code
- mental_health
- classifier
library_name: transformers
---

# CBT Cognitive Distortion Classifier

A fine-tuned DistilBERT model for detecting cognitive distortions in text, based on Cognitive Behavioral Therapy (CBT) principles.

## Model Description

This model identifies 5 common cognitive distortions in conversational text:

- **Overgeneralization**: Using "always", "never", "everyone", "nobody"
- **Catastrophizing**: Using "terrible", "awful", "worst", "disaster"
- **Black and White Thinking**: All-or-nothing, either/or patterns
- **Self-Blame**: "My fault", "blame myself", "guilty"
- **Mind Reading**: "They think", "must think", "probably think"

## Model Details

- **Base Model**: `distilbert-base-uncased`
- **Task**: Multi-label classification
- **Number of Labels**: 5
- **Training Data**: 231 samples from mental health conversational data
- **Training Split**: 196 train / 35 test
- **Framework**: HuggingFace Transformers

## Training Performance

| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1 | 0.1200 | 0.0857 |
| 2 | 0.0322 | 0.0258 |
| 3 | 0.0165 | 0.0129 |
| 4 | 0.0335 | 0.0084 |
| 5 | 0.0079 | 0.0067 |
| 6 | 0.0066 | 0.0056 |
| 7 | 0.0311 | 0.0048 |
| 8 | 0.0523 | 0.0045 |
| 9 | 0.0051 | 0.0044 |
| 10 | 0.0278 | 0.0043 |

**Final Validation Loss**: 0.0043

## Training Configuration

- **Epochs**: 10
- **Batch Size**: 8
- **Evaluation Strategy**: Per epoch
- **Optimizer**: AdamW (default)
- **Max Sequence Length**: 128
- **Device**: GPU (Tesla T4)

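
As a rough worked example, the train split and batch size listed in this card imply the number of optimizer steps computed below. This is only a sketch: it assumes a single device, no gradient accumulation, and that the final partial batch is kept, none of which is stated in this card.

```python
import math

# Values taken from this card
train_samples = 196
batch_size = 8
epochs = 10

# Assumption: the last partial batch is kept and there is no gradient accumulation
steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * epochs

print(f"steps per epoch: {steps_per_epoch}")  # 25
print(f"total steps: {total_steps}")          # 250
```

With so few updates, the epoch-to-epoch swings in training loss are probably best read as noise from a very small dataset rather than as a meaningful trend.
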
## Usage

### Loading the Model

```python
import json

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained("./cbt_model_final")
tokenizer = AutoTokenizer.from_pretrained("./cbt_model_final")

# Load label mappings
with open("./cbt_model_final/label_config.json", "r") as f:
    label_config = json.load(f)

# Keys are strings (e.g. "0") because JSON object keys are always strings
id2label = label_config["id2label"]
```

### Making Predictions

```python
def predict_distortions(text, threshold=0.5):
    """Return the distortions whose sigmoid probability exceeds `threshold`."""
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

    # Get predictions (no gradients needed at inference time)
    with torch.no_grad():
        outputs = model(**inputs)
        # Drop the batch dimension only, leaving one probability per label
        probabilities = torch.sigmoid(outputs.logits).squeeze(0)

    # Extract distortions above threshold
    detected = []
    for idx, prob in enumerate(probabilities):
        score = prob.item()
        if score > threshold:
            label = id2label[str(idx)]
            detected.append({
                "distortion": label,
                "confidence": f"{score:.2%}"
            })

    return detected

# Example usage
text = "I always mess everything up. This is a disaster!"
distortions = predict_distortions(text)

for d in distortions:
    print(f"{d['distortion']}: {d['confidence']}")
```

### Example Output

```
Input: "I always mess everything up. This is a disaster!"

Detected distortions:
overgeneralization: 73.45%
catastrophizing: 68.92%
```

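Because the task is multi-label, each label gets its own sigmoid probability and is compared with the threshold independently, so a single text can trigger zero, one, or several distortions. The sketch below shows that step in plain Python. The logits and the label order are made up for illustration and are not output from this model.

```python
import math

def sigmoid(x: float) -> float:
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical label order and logits, for illustration only
labels = ["overgeneralization", "catastrophizing", "black_and_white_thinking",
          "self_blame", "mind_reading"]
logits = [1.0, 0.8, -2.0, -3.0, -1.5]

threshold = 0.5
probs = [sigmoid(z) for z in logits]

# Each label is tested on its own; the probabilities need not sum to 1
detected = [(name, p) for name, p in zip(labels, probs) if p > threshold]

for name, p in detected:
    print(f"{name}: {p:.2%}")
```

With a threshold of 0.5, this is equivalent to keeping every label whose logit is positive. Raising the threshold trades recall for precision.
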
## Model Limitations

⚠️ **Important Considerations**:

- **Small Training Dataset**: Only 231 samples, so the model may not generalize well to all contexts
- **Rule-Based Labels**: Training labels were created using keyword matching, not expert annotations
- **Prototype Quality**: This is a proof-of-concept model, not production-ready
- **Low Confidence Scores**: Average predictions are 0.25-0.73%, indicating the model is conservative
- **Limited Context**: Only trained on short conversational patterns
- **No Clinical Validation**: Not validated by mental health professionals

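
The rule-based labeling mentioned in this card can be pictured as a simple keyword lookup. The sketch below reuses the trigger phrases listed under Model Description. The actual labeling rules are not published here, so the function name, the label names, and the black-and-white thinking phrases are all assumptions.

```python
# Trigger phrases copied from the Model Description section.
# The black_and_white_thinking entries are hypothetical stand-ins,
# since the card only describes that pattern in general terms.
KEYWORDS = {
    "overgeneralization": ["always", "never", "everyone", "nobody"],
    "catastrophizing": ["terrible", "awful", "worst", "disaster"],
    "black_and_white_thinking": ["either", "all or nothing"],
    "self_blame": ["my fault", "blame myself", "guilty"],
    "mind_reading": ["they think", "must think", "probably think"],
}

def keyword_labels(text: str) -> dict[str, int]:
    """Return a 0/1 flag per distortion based on substring matches."""
    lowered = text.lower()
    return {
        label: int(any(phrase in lowered for phrase in phrases))
        for label, phrases in KEYWORDS.items()
    }

labels = keyword_labels("I always mess everything up. This is a disaster!")
print(labels)
```

Substring matching of this kind ignores negation and context, and it fires on words such as "nevertheless". A model trained on such labels may largely learn to reproduce the keyword rules.
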
## Recommendations for Improvement

1. **Expand Dataset**: Collect more diverse, expert-annotated examples
2. **Better Labeling**: Use clinical experts to label cognitive distortions
3. **Data Augmentation**: Generate synthetic examples for underrepresented patterns
4. **Hyperparameter Tuning**: Experiment with learning rates, batch sizes, epochs
5. **Evaluation Metrics**: Add precision, recall, F1-score tracking
6. **Class Balancing**: Address imbalanced distribution of distortion types

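
For the evaluation-metrics item, per-label precision, recall, and F1 can be computed from binary targets and thresholded predictions. The following is a minimal dependency-free sketch. The helper name `per_label_prf` and the sample data are made up for illustration and are not results from this model.

```python
def per_label_prf(y_true, y_pred, label_names):
    """Compute precision, recall and F1 for each label column of 0/1 rows."""
    results = {}
    for j, name in enumerate(label_names):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 1 and p[j] == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 0 and p[j] == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 1 and p[j] == 0)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        results[name] = {"precision": precision, "recall": recall, "f1": f1}
    return results

# Made-up targets and predictions for two labels, for illustration only
names = ["overgeneralization", "catastrophizing"]
y_true = [[1, 0], [1, 1], [0, 1], [0, 0]]
y_pred = [[1, 0], [0, 1], [0, 1], [1, 0]]

metrics = per_label_prf(y_true, y_pred, names)
for name, m in metrics.items():
    print(name, {k: round(v, 2) for k, v in m.items()})
```

Reporting these per label, rather than a single averaged score, also makes the class imbalance mentioned in item 6 visible.
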
## Files Included

```
cbt_model_final/
├── config.json              # Model configuration
├── model.safetensors        # Model weights
├── tokenizer_config.json    # Tokenizer configuration
├── vocab.txt                # Vocabulary
├── special_tokens_map.json  # Special tokens
├── tokenizer.json           # Tokenizer data
└── label_config.json        # Label mappings
```

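Before loading the model, it can help to confirm that the directory is complete. The check below is a small sketch using only the standard library. `missing_files` is a hypothetical helper, and the path assumes the model was saved to `./cbt_model_final` as in the usage examples.

```python
from pathlib import Path

# File names taken from the listing in this card
EXPECTED_FILES = [
    "config.json",
    "model.safetensors",
    "tokenizer_config.json",
    "vocab.txt",
    "special_tokens_map.json",
    "tokenizer.json",
    "label_config.json",
]

def missing_files(model_dir: str) -> list[str]:
    """Return the expected file names that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]

missing = missing_files("./cbt_model_final")
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All expected files are present.")
```
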
## License

This model is based on DistilBERT and inherits its Apache 2.0 license.

## Citation

```
Base Model: DistilBERT
Original Paper: Sanh et al. (2019) - DistilBERT, a distilled version of BERT
Fine-tuning: Custom CBT distortion detection
```

## Disclaimer

⚠️ **This model is for educational and research purposes only.** It should not be used as a substitute for professional mental health diagnosis or treatment. Always consult qualified mental health professionals for clinical applications.

---

**Created**: December 2025

**Framework**: HuggingFace Transformers

**Hardware**: Kaggle GPU (Tesla T4)