---
license: cc-by-nc-4.0
tags:
- mental-health
- social-media
- symptom-identification
- disease-detection
---

# 🧩 PsySym: Symptom Identification & Disease Detection System

## 📖 Model Overview

The relevant training code is available here:
[![GitHub](https://img.shields.io/badge/Training_Code-GitHub-black?logo=github&style=flat-square)](https://github.com/blmoistawinde/EMNLP22-PsySym)

**What is PsySym?**

**PsySym** is a comprehensive framework for interpretable mental disease detection on social media. Unlike "black-box" models that directly predict diseases from text, PsySym first identifies specific psychiatric symptoms defined in clinical manuals (DSM-5) and then uses these symptom profiles to detect mental disorders.

<div align="center">
    <img src="./assets_psysym/framework.png" width="600" alt="PsySym Framework" />
    <em>Figure 1: Comparison between pure-text and symptom-assisted mental disease detection.</em>
</div>

This repository contains the models described in the paper **["Symptom Identification for Interpretable Detection of Multiple Mental Disorders on Social Media"](https://aclanthology.org/2022.emnlp-main.677/)** (EMNLP 2022).

The system consists of three distinct components:
1.  **Symptom Relevance Model (`relevance_model`)**: A multi-label classifier that identifies 38 symptom categories from social media sentences.
2.  **Symptom Status Model (`status_model`)**: A model that determines the uncertainty status of the identified symptoms (e.g., distinguishing "I have insomnia" from "I don't have insomnia").
3.  **Disease Detection Model (`disease_model`)**: A CNN-based model that predicts mental disorders (e.g., Depression, Anxiety) based on the symptom feature sequences extracted from user timelines.

### Architecture
* **Relevance & Status Models**: Based on **BERT** (MentalBERT-base) with a linear classification head.
* **Disease Model**: A custom **CNN** that aggregates symptom features across a user's posting history.
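The Quick Start gives the default `DiseaseConfig` values (`filter_num=50`, `filter_sizes=(2, 3, 4, 5, 6)`, `max_pooling_k=5`), and the CNN's tensor shapes follow directly from them. The plain-Python sketch below reproduces that shape arithmetic so it can be checked without PyTorch or any weights. `cnn_shapes` is a hypothetical helper, and the per-disease configs on the Hub may use different values.

```python
def cnn_shapes(num_posts, filter_num=50, filter_sizes=(2, 3, 4, 5, 6), k=5):
    """Return per-branch conv lengths, values kept by k-max pooling, and hidden size.

    Hypothetical helper mirroring KMaxMeanCNN under the default DiseaseConfig values.
    """
    if num_posts < max(filter_sizes):
        # nn.Conv1d with no padding needs at least `kernel_size` time steps.
        raise ValueError(f"need at least {max(filter_sizes)} posts, got {num_posts}")
    # Conv1d with kernel size s, stride 1 and no padding yields L - s + 1 steps.
    conv_lengths = [num_posts - s + 1 for s in filter_sizes]
    # The slice [:, :, -k:] keeps at most k values per channel.
    kept = [min(k, length) for length in conv_lengths]
    # Each branch is averaged to `filter_num` values; branches are concatenated.
    hidden_size = len(filter_sizes) * filter_num
    return conv_lengths, kept, hidden_size


if __name__ == "__main__":
    lengths, kept, hidden = cnn_shapes(num_posts=50)
    print(lengths)  # [49, 48, 47, 46, 45]
    print(kept)     # [5, 5, 5, 5, 5]
    print(hidden)   # 250
```

With these defaults, each user's history is reduced to a 250-dimensional vector before the final linear layer. Histories shorter than six posts cannot pass through the widest convolution.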

## 📂 Repository Structure

This repository uses **subfolders** to store the weights for different models.

| Subfolder | Task Description | Input | Output |
| :--- | :--- | :--- | :--- |
| `relevance_model/` | Identifies which of the 38 symptoms are present. | Text (Sentence) | Logits (Dim: 38) |
| `status_model/` | Estimates the uncertainty of the symptom. | Text (Sentence) | Logits (Dim: 1) |
| `disease_model/{disease_name}/` | Detects a specific mental disease (e.g., `depression`, `anxiety`). | Symptom feature sequence (`[Seq_Len, 38]` per user) | Logits (Dim: 1) |

<div align="center">
    <img src="./assets_psysym/pipeline.png" width="600" alt="PsySym Pipeline" />
    <em>Figure 2: The proposed symptom-assisted MDD pipeline.</em>
</div>

## 🚀 Quick Start (Copy & Run)

Since these models use custom architectures, **you must define the model classes locally** before loading the weights.

### 1. Installation

```bash
pip install transformers torch huggingface_hub safetensors
```

### 2. Define Model Architectures

**A. For Relevance & Status Models (BERT-based)**

```python
import torch
from torch import nn
from transformers import AutoModel, AutoConfig

class BERTDiseaseClassifier(nn.Module):
    def __init__(self, model_type, num_symps) -> None:
        super().__init__()
        self.model_type = model_type
        self.num_symps = num_symps
        self.encoder = AutoModel.from_pretrained(model_type)
        self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob)
        self.clf = nn.Linear(self.encoder.config.hidden_size, num_symps)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        x = outputs.last_hidden_state[:, 0, :]   # [CLS] pooling
        x = self.dropout(x)
        logits = self.clf(x)
        return logits
```

**B. For Disease Detection Models (CNN-based)**

```python
import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig

class DiseaseConfig(PretrainedConfig):
    model_type = "kmax_mean_cnn"
    def __init__(self, in_dim=38, filter_num=50, filter_sizes=(2, 3, 4, 5, 6), dropout=0.2, max_pooling_k=5, **kwargs):
        super().__init__(**kwargs)
        self.in_dim = in_dim
        self.filter_num = filter_num
        self.filter_sizes = filter_sizes
        self.dropout = dropout
        self.max_pooling_k = max_pooling_k

def kmax_pooling(x, k):
    # Keep the k largest values along the time axis (dim 2), in ascending order
    return x.sort(dim=2)[0][:, :, -k:]

class KMaxMeanCNN(PreTrainedModel):
    config_class = DiseaseConfig
    def __init__(self, config):
        super().__init__(config)
        self.filter_num = config.filter_num
        self.filter_sizes = config.filter_sizes
        self.hidden_size = len(config.filter_sizes) * config.filter_num
        self.max_pooling_k = config.max_pooling_k
        self.convs = nn.ModuleList([nn.Conv1d(config.in_dim, config.filter_num, size) for size in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(self.hidden_size, 1)
        self.post_init()

    def forward(self, input_seqs, **kwargs):
        # input_seqs shape: [Batch, SeqLen, InDim]
        input_seqs = input_seqs.transpose(1, 2)  # -> [Batch, InDim, SeqLen], as Conv1d expects
        x = [F.relu(conv(input_seqs)) for conv in self.convs]
        x = [kmax_pooling(item, self.max_pooling_k).mean(2) for item in x]
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits
```
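`kmax_pooling` sorts each channel along the time axis and keeps its `k` largest activations. `forward` then averages those values. The plain-Python sketch below applies the same k-max *mean* pooling to a single channel, which is handy for sanity-checking the behaviour, including the case where fewer than `k` positions exist. `kmax_mean` is a hypothetical name used only for illustration. In PyTorch, `torch.topk(x, k, dim=2).values.mean(2)` gives the same result as the sort-and-slice version whenever that dimension has at least `k` elements. With fewer elements, `topk` raises an error, while the slice keeps everything.

```python
from statistics import mean


def kmax_mean(values, k):
    """Mean of the k largest values, mirroring x.sort(dim=2)[0][:, :, -k:].mean(2)."""
    # Sorting ascending and slicing [-k:] keeps the k largest values;
    # with fewer than k values, the slice simply keeps all of them.
    top_k = sorted(values)[-k:]
    return mean(top_k)


if __name__ == "__main__":
    channel = [0.1, 0.9, 0.4, 0.7, 0.2, 0.8]
    print(kmax_mean(channel, k=3))  # mean of 0.9, 0.8 and 0.7
```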

### 3. Usage Example

**A. Loading Relevance & Status Models**

Unlike standard BERT checkpoints, **mental/mental-bert-base-uncased** is a gated model on Hugging Face. It is listed publicly, but you must log in to your Hugging Face account and be granted access before you can download it.

For convenience and reproducibility, we recommend downloading MentalBERT once to a local directory and passing that directory, instead of the Hub ID, as `model_type` in the code below.
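You can create such a local copy with `huggingface_hub.snapshot_download`, which downloads every file in a repository to a directory of your choice. The minimal sketch below assumes you have already completed the access steps that follow (token created, logged in, access granted). The directory `./mental-bert-base-uncased` and the helper name `download_mentalbert` are only examples.

```python
from huggingface_hub import snapshot_download


def download_mentalbert(local_dir="./mental-bert-base-uncased"):
    """Download the gated MentalBERT checkpoint into `local_dir` and return the local path.

    Assumes access to mental/mental-bert-base-uncased has been granted and a
    valid token is available (e.g. after huggingface_hub.login()).
    """
    return snapshot_download(repo_id="mental/mental-bert-base-uncased", local_dir=local_dir)
```

Calling `path = download_mentalbert()` once is enough. After that, `path` can be passed as `model_type` to `BERTDiseaseClassifier`.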

#### 🔐 How to Obtain a Hugging Face Access Token

To download and use gated models (e.g., mental/mental-bert-base-uncased), you need a Hugging Face account and a valid **access token**.

Please follow the steps below:

**Step 1: Create a Hugging Face Account**

If you do not already have one, create an account at:
  - https://huggingface.co/join

**Step 2: Generate an Access Token**

1. Log in to your Hugging Face account.
2. Go to **Settings → Access Tokens**.
3. Click **“Create new token”**.
4. Choose **Read** permission (this is sufficient for downloading models).
5. Give the token a name (e.g., `mental-bert-access`).
6. Click **Create token** and copy the token.

⚠️ Keep your token private. Do not share it or commit it to public repositories.

**Step 3: Log In Programmatically**

Before loading the model, log in using the Hugging Face Hub API:

```python
from huggingface_hub import login

login()  # Paste your access token when prompted
```
This step is required in any environment where you have not already logged in (for example via `huggingface-cli login`), such as **Google Colab** or a remote server.

**Step 4: Request Access to MentalBERT**

The model `mental/mental-bert-base-uncased` is a gated repository.
You must explicitly request access on its Hugging Face model page:

- https://huggingface.co/mental/mental-bert-base-uncased

Once access is granted, you will be able to download the model using your access token.

```python
import torch
from transformers import AutoConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, login
# login()  # Required when running in an online environment (e.g., Google Colab)
# from model import BERTDiseaseClassifier

repo_id = "shallowblueQAQ/PsySym-model"
subfolder = "relevance_model"
# subfolder = "status_model"

# 1. Load Config & Tokenizer
config = AutoConfig.from_pretrained(repo_id, subfolder=subfolder)
tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder)

# 2. Initialize Model Architecture
# model = BERTDiseaseClassifier(model_type="mental/mental-bert-base-uncased", num_symps=len(config.id2label))
# Replace `/path/to/mental-bert-base-uncased` with the actual local path where MentalBERT is stored.
model = BERTDiseaseClassifier(model_type="/path/to/mental-bert-base-uncased", num_symps=len(config.id2label))

# 3. Load Weights
weights_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename="pytorch_model.bin")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# 4. Inference
text = "I had a headache yesterday." if subfolder == "relevance_model" else "Does taking away distractions from some one that has ADD distract the person more or less?"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

with torch.no_grad():
    logits = model(**inputs)
    probs = torch.sigmoid(logits)

# Display predictions (multi-label: each output is thresholded independently)
threshold = 0.5
for i, prob in enumerate(probs[0].tolist()):
    if prob > threshold:
        print(f"Detected: {config.id2label[i]} ({prob:.4f})")
```
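The loop above prints detected symptoms in label-index order. When inspecting many sentences, it can help to sort the hits by confidence. Below is a minimal pure-Python sketch of that step. `detected_symptoms` and the toy labels are hypothetical, and the helper targets the relevance model's 38-way output (the status model returns a single logit per sentence).

```python
def detected_symptoms(probs, id2label, threshold=0.5):
    """Return (label, probability) pairs above `threshold`, most confident first.

    `probs` is a flat sequence of per-symptom probabilities for one sentence,
    e.g. probs[0].tolist() from the relevance model; `id2label` maps index -> name.
    """
    hits = [(id2label[i], p) for i, p in enumerate(probs) if p > threshold]
    # Multi-label output: each symptom is thresholded on its own, so several
    # (or none) may be returned; sort so the most confident come first.
    return sorted(hits, key=lambda pair: pair[1], reverse=True)


if __name__ == "__main__":
    toy_id2label = {0: "symptom_a", 1: "symptom_b", 2: "symptom_c"}  # toy labels, not the real 38
    print(detected_symptoms([0.72, 0.10, 0.91], toy_id2label))
    # [('symptom_c', 0.91), ('symptom_a', 0.72)]
```

With the real model, call it as `detected_symptoms(probs[0].tolist(), config.id2label)` after the inference step.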

**B. Loading Disease Detection Models**

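Note: The disease model takes a sequence of symptom feature vectors as input (shape: `[Batch, Seq_Len, 38]`), not raw text. `Seq_Len` must be at least as large as the largest entry in `config.filter_sizes` (6 by default), otherwise the widest convolution fails.
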
```python
import torch
from transformers import AutoConfig
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1. Define the model architecture first
# (copy DiseaseConfig, kmax_pooling and KMaxMeanCNN from the "Define Model Architectures" section above;
#  they must match model_hf_disease.py in the training code)

# 2. Configuration
repo_id = "shallowblueQAQ/PsySym-model"
disease_name = "depression"  # Options: depression, anxiety, autism, adhd, schizophrenia, bipolar, ocd, ptsd, eating.
subfolder = f"disease_model/{disease_name}"

# 3. Load Config
config = DiseaseConfig.from_pretrained(repo_id, subfolder=subfolder)

# 4. Initialize Model
model = KMaxMeanCNN(config)

# 5. Load Weights
weights_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename="model.safetensors")
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)

model.eval()

# 6. Inference Example
# Input: a sequence of symptom probabilities (from the Relevance Model)
# Shape: [Batch_Size, Sequence_Length, Feature_Dim(38)]
# Example: Batch=1, a user with 50 posts, each post described by 38 symptom features
# torch.rand draws values in [0, 1), standing in for real probabilities
dummy_input = torch.rand(1, 50, 38)

with torch.no_grad():
    # The model expects the keyword argument 'input_seqs'
    logits = model(input_seqs=dummy_input)  # Shape: [1, 1]

    # Convert logits to probability
    prob = torch.sigmoid(logits).item()

print(f"Disease Prediction ({disease_name}): {prob:.4f}")
# A probability above 0.5 is treated as a positive detection
```
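In a real pipeline, the disease model's input comes from a user's posts rather than random numbers. The sketch below assumes one 38-dimensional relevance-probability vector per post (for example `probs[0].tolist()` from the relevance model) and stacks these vectors into the `[1, Seq_Len, 38]` layout the CNN expects. With the default config, `nn.Conv1d` raises an error on sequences shorter than 6 posts. Zero-padding short histories up to that length is an assumption made here for illustration, not necessarily what the original training code does. The helper name `build_user_sequence` is hypothetical.

```python
def build_user_sequence(post_features, feature_dim=38, min_len=6):
    """Stack per-post symptom vectors into a nested list of shape [1, Seq_Len, feature_dim].

    Hypothetical helper: `post_features` holds one probability vector per post, in
    posting order. Histories shorter than `min_len` (the largest Conv1d kernel size
    in the default DiseaseConfig) are zero-padded at the end, which is an assumption.
    """
    rows = [list(map(float, vec)) for vec in post_features]
    for row in rows:
        if len(row) != feature_dim:
            raise ValueError(f"expected {feature_dim} features per post, got {len(row)}")
    # Pad with all-zero "posts" so every convolution branch has enough time steps.
    while len(rows) < min_len:
        rows.append([0.0] * feature_dim)
    return [rows]  # add the batch dimension (batch size 1)


if __name__ == "__main__":
    posts = [[0.0] * 38 for _ in range(3)]  # three toy posts
    posts[0][5] = 0.8
    seq = build_user_sequence(posts)
    print(len(seq), len(seq[0]), len(seq[0][0]))  # 1 6 38
```

Convert the result with `torch.tensor(seq, dtype=torch.float32)` and pass it as `model(input_seqs=...)`.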


## ⚠️ Ethical Considerations & Limitations
1. Research Use Only: This model is intended for research purposes only. It is not a diagnostic tool and must not be used for self-diagnosis or clinical decision-making.

2. Bias & Errors: The model is trained on Reddit data and may reflect specific linguistic styles or biases present in that community. It may not generalize perfectly to other platforms or populations.

3. Data Privacy: The training data involves sensitive mental health disclosures. While the model weights do not directly expose user data, outputs should be handled with care to protect user privacy.

## Data Availability
This model was trained on **PsySym**, a subset derived from the **[SMHD (Self-reported Mental Health Diagnoses)](https://aclanthology.org/C18-1126/)** dataset.

**Due to the strict Data Usage Agreement of SMHD, we cannot publish the original dataset.** Researchers interested in the data must apply for access directly from the creators of [SMHD (Cohan et al., 2018)](https://ir.cs.georgetown.edu/resources/).

## Citation
If you use this model, please cite our paper:
```bibtex
@inproceedings{zhang2022symptom,
  title={Symptom Identification for Interpretable Detection of Multiple Mental Disorders on Social Media},
  author={Zhang, Zhiling and Chen, Siyuan and Wu, Mengyue and Zhu, Kenny},
  booktitle={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing},
  pages={9970--9985},
  year={2022}
}
```