update readMe (README.md)
---
license: apache-2.0
language:
- en
pipeline_tag: text-classification
tags:
- medical
---

# CBT Cognitive Distortion Classifier

A fine-tuned DistilBERT model for detecting cognitive distortions in text, based on Cognitive Behavioral Therapy (CBT) principles.

## Model Description

This model identifies five common cognitive distortions in conversational text:

- **Overgeneralization**: sweeping terms such as "always", "never", "everyone", "nobody"
- **Catastrophizing**: extreme language such as "terrible", "awful", "worst", "disaster"
- **Black and White Thinking**: all-or-nothing, either/or patterns
- **Self-Blame**: phrases such as "my fault", "blame myself", "guilty"
- **Mind Reading**: assuming others' thoughts, e.g. "they think", "must think", "probably think"
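The cue phrases above also hint at how the rule-based training labels (noted under Model Limitations) could have been generated. A minimal hypothetical sketch, with keyword lists assumed from the bullets; the pattern-based Black and White Thinking label is omitted because it is not keyword-driven:

```python
# Hypothetical keyword lists assumed from the bullets above;
# the actual lists used to label the training data are not published.
KEYWORDS = {
    "overgeneralization": ["always", "never", "everyone", "nobody"],
    "catastrophizing": ["terrible", "awful", "worst", "disaster"],
    "self_blame": ["my fault", "blame myself", "guilty"],
    "mind_reading": ["they think", "must think", "probably think"],
}

def label_text(text: str) -> list[str]:
    """Return every distortion label whose trigger phrase occurs in `text`."""
    lowered = text.lower()
    return [label for label, phrases in KEYWORDS.items()
            if any(p in lowered for p in phrases)]

print(label_text("I always mess everything up. This is a disaster!"))
# prints ['overgeneralization', 'catastrophizing']
```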

## Model Details

- **Base Model**: `distilbert-base-uncased`
- **Task**: Multi-label classification
- **Number of Labels**: 5
- **Training Data**: 231 samples from mental health conversational data
- **Training Split**: 196 train / 35 test
- **Framework**: HuggingFace Transformers
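Because the task is multi-label, the classification head's logits pass through an element-wise sigmoid rather than a softmax, so the five probabilities are independent and need not sum to 1. A pure-Python illustration with made-up logit values:

```python
import math

def sigmoid(x: float) -> float:
    """Element-wise logistic function: maps a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# Illustrative logits, one per label (these values are made up).
logits = [1.0, 0.8, -2.0, -3.0, -1.5]
probs = [sigmoid(z) for z in logits]

# Each label is decided independently against a threshold.
detected = [i for i, p in enumerate(probs) if p > 0.5]
# detected == [0, 1]: only the first two labels clear the 0.5 threshold
```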

## Training Performance

| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1     | 0.1200        | 0.0857          |
| 2     | 0.0322        | 0.0258          |
| 3     | 0.0165        | 0.0129          |
| 4     | 0.0335        | 0.0084          |
| 5     | 0.0079        | 0.0067          |
| 6     | 0.0066        | 0.0056          |
| 7     | 0.0311        | 0.0048          |
| 8     | 0.0523        | 0.0045          |
| 9     | 0.0051        | 0.0044          |
| 10    | 0.0278        | 0.0043          |

**Final Validation Loss**: 0.0043

## Training Configuration

```
- Epochs: 10
- Batch Size: 8
- Evaluation Strategy: Per epoch
- Optimizer: AdamW (default)
- Max Sequence Length: 128
- Device: GPU (Tesla T4)
```
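The settings above might translate to Hugging Face `TrainingArguments` roughly as follows; the output path and anything not listed above are assumptions, not the original run's values:

```python
from transformers import TrainingArguments

# Sketch matching the configuration above. output_dir is an assumed path;
# max sequence length (128) is applied at tokenization time, not here.
training_args = TrainingArguments(
    output_dir="./cbt_model_checkpoints",  # assumption
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="epoch",  # "evaluation_strategy" on older transformers versions
)
```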

## Usage

### Loading the Model

```python
import json

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained("./cbt_model_final")
tokenizer = AutoTokenizer.from_pretrained("./cbt_model_final")
model.eval()  # inference mode

# Load label mappings
with open("./cbt_model_final/label_config.json", "r") as f:
    label_config = json.load(f)

id2label = label_config["id2label"]
```
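The shape assumed for `label_config.json` is roughly the following; the label names here are hypothetical (inferred from the Model Description bullets), and the real file shipped with the model may differ:

```python
# Hypothetical contents of label_config.json, mirroring the five
# distortions above; the actual file may use different names or ordering.
label_config = {
    "id2label": {
        "0": "overgeneralization",
        "1": "catastrophizing",
        "2": "black_and_white_thinking",
        "3": "self_blame",
        "4": "mind_reading",
    }
}
# The inverse mapping, if needed, can be derived directly:
label_config["label2id"] = {v: k for k, v in label_config["id2label"].items()}
```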

### Making Predictions

```python
def predict_distortions(text, threshold=0.5):
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

    # Get per-label probabilities (sigmoid, since this is multi-label)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.sigmoid(outputs.logits).squeeze(0)

    # Extract distortions above the threshold
    detected = []
    for idx, prob in enumerate(probabilities):
        if prob > threshold:
            detected.append({
                "distortion": id2label[str(idx)],
                "confidence": f"{prob.item():.2%}",
            })

    return detected

# Example usage
text = "I always mess everything up. This is a disaster!"
for d in predict_distortions(text):
    print(f"{d['distortion']}: {d['confidence']}")
```

### Example Output

```
Input: "I always mess everything up. This is a disaster!"

Detected distortions:
overgeneralization: 73.45%
catastrophizing: 68.92%
```

## Model Limitations

⚠️ **Important Considerations**:

- **Small Training Dataset**: Only 231 samples; the model may not generalize well to all contexts
- **Rule-Based Labels**: Training labels were created by keyword matching, not expert annotation
- **Prototype Quality**: This is a proof-of-concept model, not production-ready
- **Low Confidence Scores**: Detected-label probabilities typically fall between 0.25 and 0.73, so the model is conservative
- **Limited Context**: Only trained on short conversational patterns
- **No Clinical Validation**: Not validated by mental health professionals

## Recommendations for Improvement

1. **Expand Dataset**: Collect more diverse, expert-annotated examples
2. **Better Labeling**: Use clinical experts to label cognitive distortions
3. **Data Augmentation**: Generate synthetic examples for underrepresented patterns
4. **Hyperparameter Tuning**: Experiment with learning rates, batch sizes, and epoch counts
5. **Evaluation Metrics**: Add precision, recall, and F1-score tracking
6. **Class Balancing**: Address the imbalanced distribution of distortion types
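For point 5, one metric worth tracking is micro-averaged F1 over the binary label vectors. A self-contained sketch on toy data (not real model output):

```python
def micro_f1(y_true, y_pred):
    """Micro-averaged F1 over lists of binary multi-label indicator rows."""
    pairs = [(t, p) for row_t, row_p in zip(y_true, y_pred)
             for t, p in zip(row_t, row_p)]
    tp = sum(1 for t, p in pairs if t and p)          # predicted 1, truly 1
    fp = sum(1 for t, p in pairs if not t and p)      # predicted 1, truly 0
    fn = sum(1 for t, p in pairs if t and not p)      # predicted 0, truly 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return (2 * precision * recall / (precision + recall)
            if precision + recall else 0.0)

# Toy example: 2 samples x 5 labels (made-up indicator vectors)
y_true = [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0]]
y_pred = [[1, 0, 0, 0, 0], [0, 0, 1, 0, 1]]
score = micro_f1(y_true, y_pred)  # tp=2, fp=1, fn=1 -> F1 = 2/3
```

The same number is available from `sklearn.metrics.f1_score(..., average="micro")` if scikit-learn is on hand.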

## Files Included

```
cbt_model_final/
├── config.json               # Model configuration
├── model.safetensors         # Model weights
├── tokenizer_config.json     # Tokenizer configuration
├── vocab.txt                 # Vocabulary
├── special_tokens_map.json   # Special tokens
├── tokenizer.json            # Tokenizer data
└── label_config.json         # Label mappings
```

## License

This model is based on DistilBERT and inherits its Apache 2.0 license.

## Citation

```
Base Model: DistilBERT
Original Paper: Sanh et al. (2019), "DistilBERT, a distilled version of BERT"
Fine-tuning: Custom CBT distortion detection
```

## Disclaimer

⚠️ **This model is for educational and research purposes only.** It should not be used as a substitute for professional mental health diagnosis or treatment. Always consult qualified mental health professionals for clinical applications.

---

**Created**: December 2024  
**Framework**: HuggingFace Transformers  
**Hardware**: Kaggle GPU (Tesla T4)