---
license: cc-by-nc-4.0
tags:
- mental-health
- social-media
- symptom-identification
- disease-detection
---

# 🧩 PsySym: Symptom Identification & Disease Detection System

## 📖 Model Overview

The training code is available at [blmoistawinde/EMNLP22-PsySym](https://github.com/blmoistawinde/EMNLP22-PsySym).

**What is PsySym?**

**PsySym** is a comprehensive framework for interpretable mental disease detection on social media. Unlike "black-box" models that directly predict diseases from text, PsySym first identifies specific psychiatric symptoms defined in clinical manuals (DSM-5) and then uses these symptom profiles to detect mental disorders.

<div align="center">
  <img src="./assets_psysym/framework.png" width="600" alt="PsySym Framework" />
  <em>Figure 1: Comparison between pure-text and symptom-assisted mental disease detection.</em>
</div>

This repository contains the models described in the paper **["Symptom Identification for Interpretable Detection of Multiple Mental Disorders on Social Media"](https://aclanthology.org/2022.emnlp-main.677/)** (EMNLP 2022).

The system consists of three components:

1. **Symptom Relevance Model (`relevance_model`)**: A multi-label classifier that identifies 38 symptom categories in social media sentences.
2. **Symptom Status Model (`status_model`)**: A model that determines the uncertainty status of identified symptoms (e.g., distinguishing "I have insomnia" from "I don't have insomnia").
3. **Disease Detection Model (`disease_model`)**: A CNN-based model that predicts mental disorders (e.g., depression, anxiety) from the symptom feature sequences extracted from user timelines.

### Architecture

* **Relevance & Status Models**: Based on **BERT** (MentalBERT-base) with a linear classification head on the [CLS] representation.
* **Disease Model**: A custom **CNN** that aggregates symptom features across a user's posting history.
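
The Disease Model's tensor shapes follow directly from the `KMaxMeanCNN` class and the default `DiseaseConfig` values given in the Quick Start (38 input features, 50 filters per width, widths 2 to 6, k = 5). The plain-Python sketch below derives them; `cnn_shapes` is a hypothetical helper, not part of this repository, and it assumes stride-1, unpadded convolutions as in `nn.Conv1d(in_dim, filter_num, size)`.

```python
def cnn_shapes(seq_len, filter_sizes=(2, 3, 4, 5, 6), filter_num=50, k=5):
    """Derive KMaxMeanCNN intermediate shapes for one user (hypothetical helper).

    Mirrors nn.Conv1d with stride 1 and no padding: output length = seq_len - size + 1.
    """
    conv_lengths = {}
    for size in filter_sizes:
        out_len = seq_len - size + 1
        if out_len < 1:
            # nn.Conv1d raises an error when the input is shorter than the kernel
            raise ValueError(f"seq_len={seq_len} is shorter than filter size {size}")
        conv_lengths[size] = out_len
    # Widths whose output is shorter than k keep fewer than k values in kmax_pooling
    short = [size for size, out_len in conv_lengths.items() if out_len < k]
    # Input size of the final nn.Linear: one block of filter_num features per width
    hidden_size = len(filter_sizes) * filter_num
    # Smallest number of posts for which every width yields at least k positions
    min_posts_full_k = max(filter_sizes) + k - 1
    return conv_lengths, short, hidden_size, min_posts_full_k


lengths, short, hidden, min_posts = cnn_shapes(50)
print(lengths)    # {2: 49, 3: 48, 4: 47, 5: 46, 6: 45}
print(short)      # []
print(hidden)     # 250
print(min_posts)  # 10
```

With the defaults, a user needs at least 6 posts for the forward pass to run at all, and at least 10 for every filter width to contribute a full set of k = 5 values.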

## 📂 Repository Structure

This repository uses **subfolders** to store the weights for different models.

| Subfolder | Task Description | Input | Output |
| :--- | :--- | :--- | :--- |
| `relevance_model/` | Identifies which of the 38 symptoms are present. | Text (sentence) | Logits (dim: 38) |
| `status_model/` | Estimates the uncertainty status of the symptom. | Text (sentence) | Logits (dim: 1) |
| `disease_model/{disease_name}/` | Detects a specific mental disease (e.g., `depression`, `anxiety`). | Sequence of symptom feature vectors, `[Batch, Seq_Len, 38]` | Logits (dim: 1) |

<div align="center">
  <img src="./assets_psysym/pipeline.png" width="600" alt="PsySym Pipeline" />
  <em>Figure 2: The proposed symptom-assisted MDD pipeline.</em>
</div>

## 🚀 Quick Start (Copy & Run)

Since these models use custom architectures, **you must define the model classes locally** before loading the weights.

### 1. Installation

```bash
pip install transformers torch huggingface_hub safetensors
```

### 2. Define Model Architectures

**A. For Relevance & Status Models (BERT-based)**

```python
import torch
from torch import nn
from transformers import AutoModel


class BERTDiseaseClassifier(nn.Module):
    """BERT encoder + linear head, used for both the relevance and the status model."""

    def __init__(self, model_type, num_symps) -> None:
        super().__init__()
        self.model_type = model_type
        self.num_symps = num_symps
        # model_type: Hub id or local path of the MentalBERT checkpoint
        self.encoder = AutoModel.from_pretrained(model_type)
        self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob)
        self.clf = nn.Linear(self.encoder.config.hidden_size, num_symps)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        x = outputs.last_hidden_state[:, 0, :]  # [CLS] pooling
        x = self.dropout(x)
        logits = self.clf(x)  # [Batch, num_symps]
        return logits
```

**B. For Disease Detection Models (CNN-based)**

```python
import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig


class DiseaseConfig(PretrainedConfig):
    model_type = "kmax_mean_cnn"

    def __init__(self, in_dim=38, filter_num=50, filter_sizes=(2, 3, 4, 5, 6), dropout=0.2, max_pooling_k=5, **kwargs):
        super().__init__(**kwargs)
        self.in_dim = in_dim              # symptom features per post
        self.filter_num = filter_num      # output channels per filter size
        self.filter_sizes = filter_sizes  # convolution widths, measured in posts
        self.dropout = dropout
        self.max_pooling_k = max_pooling_k


def kmax_pooling(x, k):
    # Keep the k largest values along the time axis (dim 2), returned in ascending order
    return x.sort(dim=2)[0][:, :, -k:]


class KMaxMeanCNN(PreTrainedModel):
    config_class = DiseaseConfig

    def __init__(self, config):
        super().__init__(config)
        self.filter_num = config.filter_num
        self.filter_sizes = config.filter_sizes
        self.hidden_size = len(config.filter_sizes) * config.filter_num
        self.max_pooling_k = config.max_pooling_k
        # One 1-D convolution per filter size, sliding over the post axis
        self.convs = nn.ModuleList([nn.Conv1d(config.in_dim, config.filter_num, size) for size in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(self.hidden_size, 1)
        self.post_init()

    def forward(self, input_seqs, **kwargs):
        # input_seqs: [Batch, SeqLen, InDim] -> Conv1d expects [Batch, InDim, SeqLen]
        input_seqs = input_seqs.transpose(1, 2)
        x = [F.relu(conv(input_seqs)) for conv in self.convs]
        # k-max pooling, then the mean of the k values: [Batch, filter_num] per conv
        x = [kmax_pooling(item, self.max_pooling_k).mean(2) for item in x]
        x = torch.cat(x, 1)  # [Batch, hidden_size]
        x = self.dropout(x)
        logits = self.fc(x)  # [Batch, 1]
        return logits
```
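
To see what `kmax_pooling(...).mean(2)` computes for a single filter channel, here is a plain-Python sketch of the same operation on a toy activation sequence (the values are made up for illustration; `kmax_mean` is a hypothetical helper). The result sits between max pooling and mean pooling: it is the average of the k strongest activations across the posting history.

```python
def kmax_mean(values, k):
    """Average of the k largest values; mirrors x.sort(dim=2)[0][:, :, -k:].mean(2) for one channel."""
    top_k = sorted(values)[-k:]  # ascending sort, keep the last k (largest) values
    return sum(top_k) / len(top_k)  # averages fewer values if len(values) < k


# Toy activations of one convolution channel across six windows of posts
activations = [1, 9, 3, 7, 5, 2]

print(kmax_mean(activations, k=3))          # 7.0 (mean of 5, 7, 9)
print(max(activations))                     # 9 (plain max pooling)
print(sum(activations) / len(activations))  # 4.5 (plain mean pooling)
```

One plausible motivation for this choice is that averaging the top k is less sensitive to a single outlier post than plain max pooling, while still ignoring long stretches of symptom-free posts that would dilute a plain mean.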

### 3. Usage Example

**A. Loading Relevance & Status Models**

Unlike standard BERT checkpoints, **mental/mental-bert-base-uncased** is a gated model on Hugging Face: you must log in to your Hugging Face account and be granted access before you can download it.

For convenience and reproducibility, we recommend downloading MentalBERT once and pointing `model_type` in the loading code below at the local checkpoint path.

#### 🔐 How to Obtain a Hugging Face Access Token

To download and use gated models (e.g., mental/mental-bert-base-uncased), you need a Hugging Face account and a valid **access token**. Follow the steps below.

**Step 1: Create a Hugging Face Account**

If you do not already have one, create an account at:

- https://huggingface.co/join

**Step 2: Generate an Access Token**

1. Log in to your Hugging Face account.
2. Go to **Settings → Access Tokens**.
3. Click **“Create new token”**.
4. Choose **Read** permission (this is sufficient for downloading models).
5. Give the token a name (e.g., `mental-bert-access`).
6. Click **Create token** and copy the token.

⚠️ Keep your token private. Do not share it or commit it to public repositories.

**Step 3: Log In Programmatically**

Before loading the model, log in using the Hugging Face Hub API:

```python
from huggingface_hub import login

login()  # Paste your access token when prompted
```

Run this step in any environment that is not yet authenticated, such as a fresh **Google Colab** session or a remote server. Alternatively, set the `HF_TOKEN` environment variable to your token.

**Step 4: Request Access to MentalBERT**

The model `mental/mental-bert-base-uncased` is a gated repository. You must explicitly request access on its Hugging Face model page:

- https://huggingface.co/mental/mental-bert-base-uncased

Once access is granted, you will be able to download the model using your access token.
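
To keep a local copy of MentalBERT, as recommended above, one option is `huggingface_hub.snapshot_download`. The helper name `download_mentalbert` and its default target directory are illustrative assumptions; the returned folder path is what you would pass as `model_type` to `BERTDiseaseClassifier`.

```python
def download_mentalbert(local_dir="./mental-bert-base-uncased"):
    """Download the gated MentalBERT checkpoint into local_dir and return that path.

    Assumes you are logged in (Step 3) and have been granted access (Step 4).
    """
    # Imported inside the function so defining the helper needs no network or login
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id="mental/mental-bert-base-uncased", local_dir=local_dir)


# Usage (performs a network download, so it is left commented out here):
# mentalbert_path = download_mentalbert()
# model = BERTDiseaseClassifier(model_type=mentalbert_path, num_symps=len(config.id2label))
```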

```python
import torch
from transformers import AutoConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, login

# login()  # Needed if this environment is not yet authenticated (e.g., Google Colab)
# from model import BERTDiseaseClassifier  # or define the class as shown in "Define Model Architectures"

repo_id = "shallowblueQAQ/PsySym-model"
subfolder = "relevance_model"
# subfolder = "status_model"

# 1. Load config & tokenizer
config = AutoConfig.from_pretrained(repo_id, subfolder=subfolder)
tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder)

# 2. Initialize the model architecture
# model = BERTDiseaseClassifier(model_type="mental/mental-bert-base-uncased", num_symps=len(config.id2label))
# Replace `/path/to/mental-bert-base-uncased` with the local path where MentalBERT is stored.
model = BERTDiseaseClassifier(model_type="/path/to/mental-bert-base-uncased", num_symps=len(config.id2label))

# 3. Load the fine-tuned weights
weights_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename="pytorch_model.bin")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# 4. Inference
text = "I had a headache yesterday." if subfolder == "relevance_model" else "Does taking away distractions from some one that has ADD distract the person more or less?"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

with torch.no_grad():
    logits = model(**inputs)
    probs = torch.sigmoid(logits)  # independent sigmoid per label (multi-label)

# Display predictions above the threshold
threshold = 0.5
for i, prob in enumerate(probs[0].tolist()):
    if prob > threshold:
        print(f"Detected: {config.id2label[i]} ({prob:.4f})")
```
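
The relevance model is multi-label: each of the 38 logits goes through its own sigmoid, so any number of symptoms (including none) can exceed the threshold, unlike a softmax over mutually exclusive classes. A reusable decoding helper might look like the plain-Python sketch below; the function name, the sorting by probability, and the toy labels are illustrative assumptions.

```python
import math


def decode_multilabel(logits, id2label, threshold=0.5):
    """Return (label, probability) pairs whose sigmoid probability exceeds threshold,
    sorted from most to least likely."""
    hits = []
    for i, logit in enumerate(logits):
        prob = 1.0 / (1.0 + math.exp(-logit))  # independent sigmoid per label
        if prob > threshold:  # strict comparison, as in the loop above
            hits.append((id2label[i], prob))
    return sorted(hits, key=lambda pair: pair[1], reverse=True)


# Toy example with three hypothetical labels (the real relevance model has 38)
toy_labels = {0: "label_a", 1: "label_b", 2: "label_c"}
print(decode_multilabel([2.0, -1.0, 3.0], toy_labels))  # label_c first, then label_a
```

With the real model, you would call `decode_multilabel(logits[0].tolist(), config.id2label)`. Note that a logit of exactly 0 maps to a probability of 0.5 and is therefore not reported under the strict `>` comparison.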

**B. Loading Disease Detection Models**

Note: the disease model takes a sequence of symptom feature vectors as input (shape `[Batch, Seq_Len, 38]`), not raw text.

```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1. Define the model architecture (must match model_hf_disease.py):
#    copy DiseaseConfig, kmax_pooling and KMaxMeanCNN from "Define Model Architectures" above.

# 2. Configuration
repo_id = "shallowblueQAQ/PsySym-model"
disease_name = "depression"  # Options: depression, anxiety, autism, adhd, schizophrenia, bipolar, ocd, ptsd, eating
subfolder = f"disease_model/{disease_name}"

# 3. Load config
config = DiseaseConfig.from_pretrained(repo_id, subfolder=subfolder)

# 4. Initialize model
model = KMaxMeanCNN(config)

# 5. Load weights
weights_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename="model.safetensors")
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
model.eval()

# 6. Inference example
# Input: a sequence of per-post symptom probabilities (from the relevance model)
# Shape: [Batch_Size, Sequence_Length, Feature_Dim (38)]
# Example: batch of 1 user with 50 posts, 38 symptom features per post.
# torch.rand draws values in [0, 1) to mimic probabilities; replace with real features.
dummy_input = torch.rand(1, 50, 38)

with torch.no_grad():
    logits = model(input_seqs=dummy_input)  # Shape: [1, 1]

# Convert logits to a probability
prob = torch.sigmoid(logits).item()
print(f"Disease Prediction ({disease_name}): {prob:.4f}")
# A probability above 0.5 is treated as a positive detection
```
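
The dummy input above stands in for real features. One way to build them, sketched here under assumptions: score every sentence of each post with the relevance model, reduce a post's sentence-level probability vectors to a single vector by element-wise max (a hypothetical choice; this README does not state how sentences are aggregated into posts), and stack the post vectors in chronological order. The helpers `post_features` and `user_sequence` are illustrative names.

```python
def post_features(sentence_probs):
    """Element-wise max over the sentence-level symptom probabilities of one post.

    sentence_probs: list of equal-length lists, one per sentence (38 values each for the real model).
    """
    if not sentence_probs:
        raise ValueError("a post needs at least one scored sentence")
    return [max(column) for column in zip(*sentence_probs)]


def user_sequence(posts):
    """Stack per-post vectors into a [SeqLen, InDim] nested list; posts must be in time order."""
    return [post_features(sentences) for sentences in posts]


# Toy user with two posts and 2 features per sentence (the real models use 38)
posts = [
    [[0.1, 0.8], [0.6, 0.2]],  # post 1: two sentences
    [[0.3, 0.4]],              # post 2: one sentence
]
print(user_sequence(posts))  # [[0.6, 0.8], [0.3, 0.4]]

# With real features: input_seqs = torch.tensor([user_sequence(posts)])  # [1, SeqLen, 38]
```

Whatever aggregation you use, the sequence must contain at least as many posts as the largest filter size (6 with the default config), otherwise `nn.Conv1d` raises an error.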

## ⚠️ Ethical Considerations & Limitations

1. **Research Use Only:** This model is intended for research purposes only. It is not a diagnostic tool and must not be used for self-diagnosis or clinical decision-making.
2. **Bias & Errors:** The model is trained on Reddit data and may reflect the linguistic styles and biases of that community. It may not generalize well to other platforms or populations.
3. **Data Privacy:** The training data involves sensitive mental health disclosures. Although the model weights do not directly expose user data, handle model outputs with care to protect user privacy.

## Data Availability

This model was trained on **PsySym**, a subset derived from the **[SMHD (Self-reported Mental Health Diagnoses)](https://aclanthology.org/C18-1126/)** dataset.

**Due to the strict Data Usage Agreement of SMHD, we cannot publish the original dataset.** Researchers interested in the data must apply for access directly from the creators of [SMHD (Cohan et al., 2018)](https://ir.cs.georgetown.edu/resources/).

## Citation

If you use this model, please cite our paper:

```bibtex
@inproceedings{zhang2022symptom,
  title={Symptom Identification for Interpretable Detection of Multiple Mental Disorders on Social Media},
  author={Zhang, Zhiling and Chen, Siyuan and Wu, Mengyue and Zhu, Kenny},
  booktitle={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing},
  pages={9970--9985},
  year={2022}
}
```