---
license: cc-by-nc-4.0
tags:
- mental-health
- social-media
- symptom-identification
- disease-detection
---

# 🧩 PsySym: Symptom Identification & Disease Detection System

## 📖 Model Overview

The relevant training code is available here:
[![GitHub](https://img.shields.io/badge/Training_Code-GitHub-black?logo=github&style=flat-square)](https://github.com/blmoistawinde/EMNLP22-PsySym)

**What is PsySym?**

**PsySym** is a comprehensive framework for interpretable mental disease detection on social media. Unlike "black-box" models that directly predict diseases from text, PsySym first identifies specific psychiatric symptoms defined in clinical manuals (DSM-5) and then uses these symptom profiles to detect mental disorders.

<div align="center">
<img src="./assets_psysym/framework.png" width="600" alt="PsySym Framework" />
<em>Figure 1: Comparison between pure-text and symptom-assisted mental disease detection.</em>
</div>

This repository contains the models described in the paper **["Symptom Identification for Interpretable Detection of Multiple Mental Disorders on Social Media"](https://aclanthology.org/2022.emnlp-main.677/)** (EMNLP 2022).

The system consists of three distinct components:
1. **Symptom Relevance Model (`relevance_model`)**: A multi-label classifier that identifies 38 symptom categories from social media sentences.
2. **Symptom Status Model (`status_model`)**: A model that determines the uncertainty status of the identified symptoms (e.g., distinguishing "I have insomnia" from "I don't have insomnia").
3. **Disease Detection Model (`disease_model`)**: A CNN-based model that predicts mental disorders (e.g., Depression, Anxiety) based on the symptom feature sequences extracted from user timelines.

### Architecture
* **Relevance & Status Models**: Based on **BERT** (MentalBERT-base) with a linear classification head.
* **Disease Model**: A custom **K-Max Pooling CNN** that aggregates symptom features across a user's posting history.

## 📂 Repository Structure

This repository uses **subfolders** to store the weights for different models.

| Subfolder | Task Description | Input | Output |
| :--- | :--- | :--- | :--- |
| `relevance_model/` | Identifies which of the 38 symptoms are present. | Text (Sentence) | Logits (Dim: 38) |
| `status_model/` | Estimates the uncertainty of the symptom. | Text (Sentence) | Logits (Dim: 1) |
| `disease_model/{disease_name}/` | Detects a specific mental disease (e.g., `depression`, `anxiety`). | Symptom Features Vector | Logits (Dim: 1) |

<div align="center">
<img src="./assets_psysym/pipeline.png" width="600" alt="PsySym Pipeline" />
<em>Figure 2: The proposed symptom-assisted MDD pipeline.</em>
</div>

## 🚀 Quick Start (Copy & Run)

Since these models use custom architectures, **you must define the model classes locally** before loading the weights.

### 1. Installation

```bash
pip install transformers torch huggingface_hub safetensors
```

### 2. Define Model Architectures

**A. For Relevance & Status Models (BERT-based)**

```python
import torch
from torch import nn
from transformers import AutoModel

class BERTDiseaseClassifier(nn.Module):
    def __init__(self, model_type, num_symps) -> None:
        super().__init__()
        self.model_type = model_type
        self.num_symps = num_symps
        self.encoder = AutoModel.from_pretrained(model_type)
        self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob)
        self.clf = nn.Linear(self.encoder.config.hidden_size, num_symps)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        outputs = self.encoder(input_ids, attention_mask, token_type_ids)
        x = outputs.last_hidden_state[:, 0, :]  # [CLS] pooling
        x = self.dropout(x)
        logits = self.clf(x)
        return logits
```
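
As a quick offline sanity check of the classifier head above, the sketch below wires a tiny randomly initialized BERT (a stand-in for MentalBERT, so no checkpoint download is needed; the config values here are arbitrary) to a 38-way linear head and confirms the output shapes:

```python
import torch
from torch import nn
from transformers import BertConfig, BertModel

# Tiny random BERT as a stand-in for MentalBERT (hypothetical sizes)
config = BertConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
                    intermediate_size=64, vocab_size=100)
encoder = BertModel(config)
clf = nn.Linear(config.hidden_size, 38)  # 38 symptom labels

input_ids = torch.randint(0, 100, (2, 10))           # [batch=2, tokens=10]
cls = encoder(input_ids).last_hidden_state[:, 0, :]  # [CLS] pooling -> [2, 32]
logits = clf(cls)                                    # -> [2, 38]
print(logits.shape)  # torch.Size([2, 38])
```

Real use loads MentalBERT as shown in the usage example below; only the shapes matter here.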

**B. For Disease Detection Models (CNN-based)**

```python
import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig

class DiseaseConfig(PretrainedConfig):
    model_type = "kmax_mean_cnn"

    def __init__(self, in_dim=38, filter_num=50, filter_sizes=(2, 3, 4, 5, 6), dropout=0.2, max_pooling_k=5, **kwargs):
        super().__init__(**kwargs)
        self.in_dim = in_dim
        self.filter_num = filter_num
        self.filter_sizes = filter_sizes
        self.dropout = dropout
        self.max_pooling_k = max_pooling_k

def kmax_pooling(x, k):
    # Keep the k largest values along dim 2 (sort() returns them in ascending order)
    return x.sort(dim=2)[0][:, :, -k:]

class KMaxMeanCNN(PreTrainedModel):
    config_class = DiseaseConfig

    def __init__(self, config):
        super().__init__(config)
        self.filter_num = config.filter_num
        self.filter_sizes = config.filter_sizes
        self.hidden_size = len(config.filter_sizes) * config.filter_num
        self.max_pooling_k = config.max_pooling_k
        self.convs = nn.ModuleList([nn.Conv1d(config.in_dim, config.filter_num, size) for size in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(self.hidden_size, 1)
        self.post_init()

    def forward(self, input_seqs, **kwargs):
        # input_seqs shape: [Batch, SeqLen, InDim]
        input_seqs = input_seqs.transpose(1, 2)  # -> [Batch, InDim, SeqLen]
        x = [F.relu(conv(input_seqs)) for conv in self.convs]
        x = [kmax_pooling(item, self.max_pooling_k).mean(2) for item in x]
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits
```
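
The `kmax_pooling` helper above keeps the k largest activations along the sequence dimension; the CNN then averages them ("k-max mean"). A minimal standalone check of that behavior:

```python
import torch

def kmax_pooling(x, k):
    # Keep the k largest values along dim 2 (sort() returns them in ascending order)
    return x.sort(dim=2)[0][:, :, -k:]

x = torch.tensor([[[3.0, 1.0, 2.0, 5.0, 4.0]]])  # shape [1, 1, 5]
topk = kmax_pooling(x, 2)
print(topk)          # tensor([[[4., 5.]]])
print(topk.mean(2))  # tensor([[4.5000]]) -- the k-max mean used in forward()
```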

### 3. Usage Example

**A. Loading Relevance & Status Models**

Unlike standard BERT checkpoints, **mental/mental-bert-base-uncased** is a gated (non-public) model on Hugging Face: you must log in to your Hugging Face account and obtain access permission before you can download it.

For convenience and reproducibility, we recommend downloading MentalBERT locally and replacing the MentalBERT path in the code with the local checkpoint path.

```python
import torch
from transformers import AutoConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, login
# login()  # Required when running in an online environment (e.g., Google Colab)
# from model import BERTDiseaseClassifier

repo_id = "shallowblueQAQ/PsySym-model"
subfolder = "relevance_model"
# subfolder = "status_model"

# 1. Load Config & Tokenizer
config = AutoConfig.from_pretrained(repo_id, subfolder=subfolder)
tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder)

# 2. Initialize Model Architecture
# model = BERTDiseaseClassifier(model_type="mental/mental-bert-base-uncased", num_symps=len(config.id2label))
# Replace `/path/to/mental-bert-base-uncased` with the actual local path where MentalBERT is stored.
model = BERTDiseaseClassifier(model_type="/path/to/mental-bert-base-uncased", num_symps=len(config.id2label))

# 3. Load Weights
weights_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename="pytorch_model.bin")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# 4. Inference
text = "I had a headache yesterday." if subfolder == "relevance_model" else "Does taking away distractions from some one that has ADD distract the person more or less?"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

with torch.no_grad():
    logits = model(**inputs)
    probs = torch.sigmoid(logits)

# Display Predictions (Multi-label)
threshold = 0.5
for i, prob in enumerate(probs[0]):
    if prob > threshold:
        print(f"Detected: {config.id2label[i]} ({prob:.4f})")
```

**B. Loading Disease Detection Models**

Note: the disease model takes symptom feature vectors as input (shape `[Batch, Seq_Len, 38]`), not raw text.

```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1. Define the model architecture
# (Copy the DiseaseConfig and KMaxMeanCNN class definitions from the
#  "Define Model Architectures" section above; they must match model_hf_disease.py.)

# 2. Configuration
repo_id = "shallowblueQAQ/PsySym-model"
disease_name = "depression"  # Options: depression, anxiety, autism, adhd, schizophrenia, bipolar, ocd, ptsd, eating
subfolder = f"disease_model/{disease_name}"

# 3. Load Config
config = DiseaseConfig.from_pretrained(repo_id, subfolder=subfolder)

# 4. Initialize Model
model = KMaxMeanCNN(config)

# 5. Load Weights
weights_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename="model.safetensors")
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
model.eval()

# 6. Inference Example
# Input: a sequence of symptom probabilities (from the Relevance Model)
# Shape: [Batch_Size, Sequence_Length, Feature_Dim(38)]
# Example: batch of 1 user with 50 posts, each carrying 38 symptom features
dummy_input = torch.randn(1, 50, 38)

with torch.no_grad():
    logits = model(input_seqs=dummy_input)  # Shape: [1, 1]

# Convert logits to a probability
prob = torch.sigmoid(logits).item()

print(f"Disease Prediction ({disease_name}): {prob:.4f}")
# A probability above 0.5 implies the disease is detected
```
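
To connect the two stages end to end, the per-post symptom probabilities produced by the relevance model can be stacked into the `[Batch, SeqLen, 38]` tensor the disease model expects. A minimal sketch, where random vectors stand in for `sigmoid(relevance_model(post))` over one user's timeline:

```python
import torch

num_posts, num_symptoms = 50, 38
# Stand-ins for per-post symptom probabilities from the relevance model
post_probs = [torch.rand(num_symptoms) for _ in range(num_posts)]

input_seqs = torch.stack(post_probs).unsqueeze(0)  # add the batch dimension
print(input_seqs.shape)  # torch.Size([1, 50, 38])
# input_seqs can now be fed to the disease model: model(input_seqs=input_seqs)
```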

## ⚠️ Ethical Considerations & Limitations

1. **Research Use Only**: This model is intended for research purposes only. It is not a diagnostic tool and must not be used for self-diagnosis or clinical decision-making.
2. **Bias & Errors**: The model is trained on Reddit data and may reflect specific linguistic styles or biases present in that community. It may not generalize perfectly to other platforms or populations.
3. **Data Privacy**: The training data involves sensitive mental health disclosures. While the model weights do not directly expose user data, outputs should be handled with care to protect user privacy.

## Data Availability

This model was trained on **PsySym**, a subset derived from the **[SMHD (Self-reported Mental Health Diagnoses)](https://aclanthology.org/C18-1126/)** dataset.

**Due to the strict Data Usage Agreement of SMHD, we cannot publish the original dataset.** Researchers interested in the data must apply for access directly from the creators of [SMHD (Cohan et al., 2018)](https://ir.cs.georgetown.edu/resources/).

## Citation

If you use this model, please cite our paper:

```bibtex
@inproceedings{zhang2022symptom,
  title={Symptom Identification for Interpretable Detection of Multiple Mental Disorders on Social Media},
  author={Zhang, Zhiling and Chen, Siyuan and Wu, Mengyue and Zhu, Kenny},
  booktitle={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing},
  pages={9970--9985},
  year={2022}
}
```