YureiYuri committed commit d8d20b8 (verified) · Parent(s): 51d6da3

update readMe

Files changed (1): README.md (+170 −2)
---
license: apache-2.0
language:
- en
pipeline_tag: text-classification
tags:
- medical
---

# CBT Cognitive Distortion Classifier

A fine-tuned DistilBERT model for detecting cognitive distortions in text, based on Cognitive Behavioral Therapy (CBT) principles.

## Model Description

This model identifies 5 common cognitive distortions in conversational text:

- **Overgeneralization**: Using "always", "never", "everyone", "nobody"
- **Catastrophizing**: Using "terrible", "awful", "worst", "disaster"
- **Black and White Thinking**: All-or-nothing, either/or patterns
- **Self-Blame**: "My fault", "blame myself", "guilty"
- **Mind Reading**: "They think", "must think", "probably think"

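The cue words above mirror the rule-based labeling described under Model Limitations. A minimal sketch of that kind of keyword heuristic, where the keyword lists are illustrative guesses rather than the actual labeling rules used for training:

```python
# Hypothetical cue-word lists; only the words quoted above come from
# this model card, the rest are illustrative additions.
DISTORTION_KEYWORDS = {
    "overgeneralization": ["always", "never", "everyone", "nobody"],
    "catastrophizing": ["terrible", "awful", "worst", "disaster"],
    "black_and_white_thinking": ["all or nothing", "either", "completely"],
    "self_blame": ["my fault", "blame myself", "guilty"],
    "mind_reading": ["they think", "must think", "probably think"],
}

def keyword_labels(text):
    """Return the set of distortions whose cue words appear in the text."""
    lowered = text.lower()
    return {
        name
        for name, cues in DISTORTION_KEYWORDS.items()
        if any(cue in lowered for cue in cues)
    }
```

This kind of substring matching is cheap but noisy, which is why the model card flags the resulting labels as a limitation.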
## Model Details

- **Base Model**: `distilbert-base-uncased`
- **Task**: Multi-label classification
- **Number of Labels**: 5
- **Training Data**: 231 samples from mental health conversational data
- **Training Split**: 196 train / 35 test
- **Framework**: HuggingFace Transformers

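For reference, a 5-label multi-label head on DistilBERT is typically configured via `problem_type`, which switches the loss to `BCEWithLogitsLoss` so each label gets an independent sigmoid probability. A sketch with randomly initialized weights (not this repository's training code):

```python
from transformers import DistilBertConfig, DistilBertForSequenceClassification

# Multi-label head with 5 outputs; problem_type selects BCEWithLogitsLoss.
config = DistilBertConfig(
    num_labels=5,
    problem_type="multi_label_classification",
)
model = DistilBertForSequenceClassification(config)  # random weights, no download
```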
## Training Performance

| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1     | 0.1200        | 0.0857          |
| 2     | 0.0322        | 0.0258          |
| 3     | 0.0165        | 0.0129          |
| 4     | 0.0335        | 0.0084          |
| 5     | 0.0079        | 0.0067          |
| 6     | 0.0066        | 0.0056          |
| 7     | 0.0311        | 0.0048          |
| 8     | 0.0523        | 0.0045          |
| 9     | 0.0051        | 0.0044          |
| 10    | 0.0278        | 0.0043          |

**Final Validation Loss**: 0.0043

## Training Configuration

```
Epochs: 10
Batch Size: 8
Evaluation Strategy: per epoch
Optimizer: AdamW (default)
Max Sequence Length: 128
Device: GPU (Tesla T4)
```

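The original training script is not included in this repository. Under the settings above, the Trainer arguments might look roughly like this; `output_dir` is an assumption, and the argument name for the evaluation strategy is `eval_strategy` in recent `transformers` versions (`evaluation_strategy` in older ones):

```python
# Hypothetical reconstruction of the configuration listed above,
# kept as a plain dict so it is independent of transformers versions.
training_kwargs = dict(
    output_dir="./cbt_model",           # assumed path
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="epoch",              # evaluate once per epoch
)

# With transformers installed, this would feed TrainingArguments:
# training_args = TrainingArguments(**training_kwargs)
```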
## Usage

### Loading the Model

```python
import json

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model and tokenizer
model = AutoModelForSequenceClassification.from_pretrained("./cbt_model_final")
tokenizer = AutoTokenizer.from_pretrained("./cbt_model_final")

# Load label mappings
with open("./cbt_model_final/label_config.json", "r") as f:
    label_config = json.load(f)

id2label = label_config["id2label"]
```

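The exact contents of `label_config.json` depend on the training run; a hypothetical example of its shape (the label names and their order here are guesses), which also shows why the indices are string keys after JSON deserialization:

```python
import json

# Hypothetical label_config.json contents, NOT the file shipped with
# this model; JSON object keys always deserialize as strings.
example_label_config = {
    "id2label": {
        "0": "overgeneralization",
        "1": "catastrophizing",
        "2": "black_and_white_thinking",
        "3": "self_blame",
        "4": "mind_reading",
    }
}

# Round-trip through JSON, as when reading the real file from disk
id2label = json.loads(json.dumps(example_label_config))["id2label"]
```

This string-keyed mapping is why the prediction code below indexes with `id2label[str(idx)]`.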
### Making Predictions

```python
def predict_distortions(text, threshold=0.5):
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.sigmoid(outputs.logits).squeeze()

    # Extract distortions above threshold
    detected = []
    for idx, prob in enumerate(probabilities):
        if prob > threshold:
            label = id2label[str(idx)]
            detected.append({
                "distortion": label,
                "confidence": f"{prob.item():.2%}",
            })

    return detected

# Example usage
text = "I always mess everything up. This is a disaster!"
distortions = predict_distortions(text)

for d in distortions:
    print(f"{d['distortion']}: {d['confidence']}")
```

### Example Output

```
Input: "I always mess everything up. This is a disaster!"

Detected distortions:
overgeneralization: 73.45%
catastrophizing: 68.92%
```

## Model Limitations

⚠️ **Important Considerations**:

- **Small Training Dataset**: Only 231 samples, so the model may not generalize well to all contexts
- **Rule-Based Labels**: Training labels were created with keyword matching, not expert annotation
- **Prototype Quality**: This is a proof-of-concept model, not production-ready
- **Low Confidence Scores**: Predicted probabilities typically fall between 0.25 and 0.73, so the model is conservative
- **Limited Context**: Trained only on short conversational patterns
- **No Clinical Validation**: Not validated by mental health professionals

## Recommendations for Improvement

1. **Expand Dataset**: Collect more diverse, expert-annotated examples
2. **Better Labeling**: Use clinical experts to label cognitive distortions
3. **Data Augmentation**: Generate synthetic examples for underrepresented patterns
4. **Hyperparameter Tuning**: Experiment with learning rates, batch sizes, epochs
5. **Evaluation Metrics**: Add precision, recall, F1-score tracking
6. **Class Balancing**: Address imbalanced distribution of distortion types

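For point 5, multi-label precision, recall, and F1 can be tracked with scikit-learn. A minimal sketch on toy arrays; `y_true` and `y_pred` are placeholders, not real model outputs:

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Toy multi-label matrices: 4 samples x 5 labels (one column per distortion).
y_true = np.array([[1, 0, 0, 0, 0],
                   [1, 1, 0, 0, 0],
                   [0, 0, 1, 0, 0],
                   [0, 0, 0, 1, 1]])
y_pred = np.array([[1, 0, 0, 0, 0],
                   [1, 0, 0, 0, 0],
                   [0, 0, 1, 0, 0],
                   [0, 0, 0, 1, 0]])

# Micro-averaging pools all label decisions before computing the metrics.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="micro", zero_division=0
)
print(f"precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
```

Swapping `average="macro"` would instead weight each distortion class equally, which matters given the class-imbalance issue noted in point 6.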
## Files Included

```
cbt_model_final/
├── config.json              # Model configuration
├── model.safetensors        # Model weights
├── tokenizer_config.json    # Tokenizer configuration
├── vocab.txt                # Vocabulary
├── special_tokens_map.json  # Special tokens
├── tokenizer.json           # Tokenizer data
└── label_config.json        # Label mappings
```

## License

This model is based on DistilBERT and inherits its Apache 2.0 license.

## Citation

```
Base Model: DistilBERT
Original Paper: Sanh et al. (2019), "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter"
Fine-tuning: Custom CBT distortion detection
```

## Disclaimer

⚠️ **This model is for educational and research purposes only.** It should not be used as a substitute for professional mental health diagnosis or treatment. Always consult qualified mental health professionals for clinical applications.

---

**Created**: December 2024  
**Framework**: HuggingFace Transformers  
**Hardware**: Kaggle GPU (Tesla T4)