---
license: cc-by-nc-4.0
tags:
- mental-health
- social-media
- symptom-identification
---

# Symptom Classification Model

## 📖 Model Overview

The system consists of two distinct components:

1. **Symptom Relevance Model (`relevance_model`)**: A multi-label classifier that identifies 32 symptom categories in social media sentences.
2. **Symptom Status Model (`status_model`)**: A model that estimates the uncertainty status of an identified symptom (e.g., distinguishing "I have insomnia" from "I don't have insomnia").
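
In a typical two-stage pipeline, the relevance model first flags candidate symptoms and the status model is then applied to the flagged sentences. The sketch below illustrates this post-processing; the function name, the threshold, and the convention that a higher status score means the symptom is affirmed are illustrative assumptions, not part of the released models:

```python
# Hypothetical post-processing: combine relevance probabilities with a status score.
# Assumes both models' logits have already been passed through a sigmoid.

def combine_predictions(relevance_probs, symptom_names, status_score, threshold=0.5):
    """Return (symptom, affirmed?) pairs for symptoms above the relevance threshold.

    The interpretation of `status_score` (higher = affirmed) is an assumption
    made for illustration; check the status model's label convention.
    """
    detected = [name for name, p in zip(symptom_names, relevance_probs) if p > threshold]
    affirmed = status_score > threshold
    return [(name, affirmed) for name in detected]

# Example with made-up numbers:
names = ["insomnia", "fatigue", "panic_fear"]
print(combine_predictions([0.91, 0.12, 0.77], names, status_score=0.85))
# [('insomnia', True), ('panic_fear', True)]
```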

### Architecture

Both models are based on **BERT** (BERT-large-uncased for the relevance model, MentalBERT-base for the status model) with a linear classification head.
## 📂 Repository Structure

This repository uses **subfolders** to store the weights for the different models.

| Subfolder | Task Description | Input | Output |
| :--- | :--- | :--- | :--- |
| `relevance_model/` | Identifies which of the 32 symptoms are present. | Text (sentence) | Logits (dim: 32) |
| `status_model/` | Estimates the uncertainty status of the symptom. | Text (sentence) | Logits (dim: 1) |
## Quick Start (Copy & Run)

Since these models use custom architectures, **you must define the model classes locally** before loading the weights.

### 1. Installation

```bash
pip install transformers torch huggingface_hub
```
### 2. Define Model Architectures

```python
import torch
from torch import nn
from transformers import AutoModel


class BERTDiseaseClassifier(nn.Module):
    def __init__(self, model_type, num_symps) -> None:
        super().__init__()
        self.model_type = model_type
        self.num_symps = num_symps
        self.encoder = AutoModel.from_pretrained(model_type)
        self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob)
        self.clf = nn.Linear(self.encoder.config.hidden_size, num_symps)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        x = outputs.last_hidden_state[:, 0, :]  # [CLS] pooling
        x = self.dropout(x)
        logits = self.clf(x)
        return logits
```
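Before downloading any pretrained weights, the pooling-plus-head logic above can be sanity-checked in isolation with random tensors (the batch shape and hidden size below are arbitrary choices for illustration):

```python
import torch
from torch import nn

# Stand-in for encoder output: batch of 2 sentences, 16 tokens, hidden size 768
hidden_states = torch.randn(2, 16, 768)

cls = hidden_states[:, 0, :]   # [CLS] pooling: first token of each sequence
head = nn.Linear(768, 32)      # linear head over the 32 symptom categories
logits = head(cls)

print(logits.shape)  # torch.Size([2, 32])
```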
### 3. Usage Example

**A. Loading Relevance & Status Models**

Unlike standard BERT models, **mental/mental-bert-base-uncased** is a gated (non-public) model on Hugging Face.
Users must log in to their Hugging Face account and obtain access permission before downloading it.

For convenience and reproducibility, we recommend downloading MentalBERT locally and replacing the MentalBERT path in the code with the local checkpoint path.
#### 🔐 How to Obtain a Hugging Face Access Token

To download and use gated models (e.g., mental/mental-bert-base-uncased), you need a Hugging Face account and a valid **access token**.

Please follow the steps below:

**Step 1: Create a Hugging Face Account**

If you do not already have one, create an account at:
- https://huggingface.co/join

**Step 2: Generate an Access Token**

1. Log in to your Hugging Face account.
2. Go to **Settings → Access Tokens**.
3. Click **“Create new token”**.
4. Choose **Read** permission (this is sufficient for downloading models).
5. Give the token a name (e.g., `mental-bert-access`).
6. Click **Create token** and copy the token.

⚠️ Keep your token private. Do not share it or commit it to public repositories.
**Step 3: Log In Programmatically**

Before loading the model, log in using the Hugging Face Hub API:

```python
from huggingface_hub import login

login()  # Paste your access token when prompted
```

This step is required when running code in online environments such as **Google Colab** or on remote servers.
**Step 4: Request Access to MentalBERT**

The model `mental/mental-bert-base-uncased` is a gated repository.
You must explicitly request access on its Hugging Face model page:

- https://huggingface.co/mental/mental-bert-base-uncased

Once access is granted, you will be able to download the model using your access token.
```python
import torch
from transformers import AutoConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, login

# login()  # Required when running in an online environment (e.g., Google Colab)
# from model import BERTDiseaseClassifier

repo_id = "shallowblueQAQ/Symptom-model"
subfolder = "relevance_model"
# subfolder = "status_model"

# 1. Load Config & Tokenizer
config = AutoConfig.from_pretrained(repo_id, subfolder=subfolder)
tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder)

# 2. Initialize Model Architecture
# model = BERTDiseaseClassifier(model_type="mental/mental-bert-base-uncased", num_symps=len(config.id2label))  # MentalBERT for the status model
# model = BERTDiseaseClassifier(model_type="bert-large-uncased", num_symps=len(config.id2label))  # BERT-large for the relevance model
# Replace `/path/to/bert-large-uncased` with the actual local path where BERT-Large is stored.
model = BERTDiseaseClassifier(model_type="/path/to/bert-large-uncased", num_symps=len(config.id2label))

# 3. Load Weights
weights_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename="pytorch_model.bin")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# 4. Inference
text = "I had a headache yesterday." if subfolder == "relevance_model" else "Does taking away distractions from some one that has ADD distract the person more or less?"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)

with torch.no_grad():
    logits = model(**inputs)
    probs = torch.sigmoid(logits)

# Display Predictions (Multi-label)
threshold = 0.5
for i, prob in enumerate(probs[0]):
    if prob > threshold:
        print(f"Detected: {config.id2label[i]} ({prob:.4f})")
```
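When `subfolder = "status_model"`, the display loop above iterates over a single logit (dim: 1). The sigmoid maps that logit to a score in (0, 1); a minimal sketch of this mapping is below. The 0.5 cutoff and the direction of the label are assumptions for illustration — consult the model's `id2label` for the actual convention:

```python
import math

def status_score(logit):
    """Map the status model's single logit to a (0, 1) score via the sigmoid."""
    return 1.0 / (1.0 + math.exp(-logit))

print(status_score(0.0))         # 0.5: maximally uncertain
print(status_score(2.0) > 0.5)   # True
```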
## Performance

| Symptom | AUC |
| :--- | :--- |
| irritability | 0.988647 |
| anxious_mood | 0.994557 |
| autonomic_symptoms | 0.998247 |
| fatigue | 0.993981 |
| depressed_mood | 0.983409 |
| gastrointestinal_symptoms | 0.999107 |
| genitourinary_sexual_symptoms | 0.998653 |
| hyperactivity_agitation | 0.988879 |
| impulsivity | 0.999138 |
| inattention | 0.996561 |
| suicidal_ideation | 0.999493 |
| worthlessness_and_guilty | 0.989372 |
| avoidance_of_stimuli | 0.996880 |
| compensatory_behaviors_to_prevent_weight_gain | 0.999132 |
| compulsions | 0.993975 |
| diminished_emotional_expression | 0.996088 |
| risky_behaviors | 0.996795 |
| mood_and_energy_shift | 0.998956 |
| fear_of_gaining_weight | 0.998365 |
| fears_of_being_negatively_evaluated | 0.998446 |
| flight_of_ideas | 0.997800 |
| intrusion | 0.998752 |
| loss_of_interest_or_motivation | 0.996038 |
| talkativeness | 0.999322 |
| obsession | 0.994734 |
| panic_fear | 0.997528 |
| memory_impairment | 0.998939 |
| sleep_disturbance | 0.998297 |
| musculoskeletal_symptoms | 0.998672 |
| derealization&dissociation | 0.996076 |
| sensory_symptoms | 0.998538 |
| weight_and_appetite_change | 0.998482 |
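As a single summary figure, the macro-average of the per-symptom AUCs can be computed directly from the table (the values below are copied from it; the macro-average itself is not reported by the authors):

```python
# Per-symptom AUCs, in table order
aucs = [
    0.988647, 0.994557, 0.998247, 0.993981, 0.983409, 0.999107, 0.998653,
    0.988879, 0.999138, 0.996561, 0.999493, 0.989372, 0.996880, 0.999132,
    0.993975, 0.996088, 0.996795, 0.998956, 0.998365, 0.998446, 0.997800,
    0.998752, 0.996038, 0.999322, 0.994734, 0.997528, 0.998939, 0.998297,
    0.998672, 0.996076, 0.998538, 0.998482,
]
macro_auc = sum(aucs) / len(aucs)
print(f"Macro-average AUC over {len(aucs)} symptoms: {macro_auc:.4f}")
```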
## ⚠️ Ethical Considerations & Limitations

1. **Research Use Only**: This model is intended for research purposes only. It is not a diagnostic tool and must not be used for self-diagnosis or clinical decision-making.
2. **Bias & Errors**: The model is trained on Reddit data and may reflect specific linguistic styles or biases present in that community. It may not generalize perfectly to other platforms or populations.
3. **Data Privacy**: The training data involves sensitive mental health disclosures. While the model weights do not directly expose user data, outputs should be handled with care to protect user privacy.