shallowblueQAQ commited on
Commit
fbf6f1c
·
verified ·
1 Parent(s): eb9e779

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +200 -3
README.md CHANGED
@@ -1,3 +1,200 @@
1
- ---
2
- license: cc-by-nc-4.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ tags:
4
+ - mental-health
5
+ - social-media
6
+ - symptom-identification
7
+ ---
8
+
9
+ # Symptom Classification Model
10
+
11
+ ## 📖 Model Overview
12
+
13
+ The system consists of two distinct components:
14
+ 1. **Symptom Relevance Model (`relevance_model`)**: A multi-label classifier that identifies 32 symptom categories from social media sentences.
15
+ 2. **Symptom Status Model (`status_model`)**: A model that determines the uncertainty status of the identified symptoms (e.g., distinguishing "I have insomnia" from "I don't have insomnia").
16
+
17
+ ### Architecture
18
+ Based on **BERT** (MentalBERT-base for status model) with a linear classification head.
19
+
20
+
21
+ ## 📂 Repository Structure
22
+
23
+ This repository uses **subfolders** to store the weights for different models.
24
+
25
+ | Subfolder | Task Description | Input | Output |
26
+ | :--- | :--- | :--- | :--- |
27
+ | `relevance_model/` | Identifies which of the 32 symptoms are present. | Text (Sentence) | Logits (Dim: 32) |
28
+ | `status_model/` | Estimates the uncertainty of the symptom. | Text (Sentence) | Logits (Dim: 1) |
29
+
30
+
31
+
32
+ ## Quick Start (Copy & Run)
33
+
34
+ Since these models use custom architectures, **you must define the model classes locally** before loading the weights.
35
+
36
+ ### 1. Installation
37
+
38
+ ```bash
39
+ pip install transformers torch huggingface_hub
40
+ ```
41
+
42
+ ### 2. Define Model Architectures
43
+
44
+
45
+ ```python
46
+ import torch
47
+ from torch import nn
48
+ from transformers import AutoModel, AutoConfig
49
+
50
+ class BERTDiseaseClassifier(nn.Module):
51
+ def __init__(self, model_type, num_symps) -> None:
52
+ super().__init__()
53
+ self.model_type = model_type
54
+ self.num_symps = num_symps
55
+ self.encoder = AutoModel.from_pretrained(model_type)
56
+ self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob)
57
+ self.clf = nn.Linear(self.encoder.config.hidden_size, num_symps)
58
+
59
+ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
60
+ outputs = self.encoder(input_ids, attention_mask, token_type_ids)
61
+ x = outputs.last_hidden_state[:, 0, :] # [CLS] pooling
62
+ x = self.dropout(x)
63
+ logits = self.clf(x)
64
+ return logits
65
+ ```
66
+
67
+ ### 3. Usage Example
68
+
69
+ **A. Loading Relevance & Status Models**
70
+
71
+ Unlike standard BERT models, **mental/mental-bert-base-uncased** is a gated (non-public) model on Hugging Face.
72
+ Users must log in to their Hugging Face account and obtain access permission before downloading it.
73
+
74
+ For convenience and reproducibility, we recommend downloading MentalBERT locally and replacing the MentalBERT path in the code with the local checkpoint path.
75
+
76
+ #### 🔐 How to Obtain a Hugging Face Access Token
77
+
78
+ To download and use gated models (e.g., mental/mental-bert-base-uncased), you need a Hugging Face account and a valid **access token**.
79
+
80
+ Please follow the steps below:
81
+
82
+ **Step 1: Create a Hugging Face Account**
83
+
84
+ If you do not already have one, create an account at:
85
+ - https://huggingface.co/join
86
+
87
+ **Step 2: Generate an Access Token**
88
+
89
+ 1. Log in to your Hugging Face account.
90
+ 2. Go to **Settings → Access Tokens**.
91
+ 3. Click **“Create new token”**.
92
+ 4. Choose **Read** permission (this is sufficient for downloading models).
93
+ 5. Give the token a name (e.g., `mental-bert-access`).
94
+ 6. Click **Create token** and copy the token.
95
+
96
+ ⚠️ Keep your token private. Do not share it or commit it to public repositories.
97
+
98
+ **Step 3: Log In Programmatically**
99
+
100
+ Before loading the model, log in using the Hugging Face Hub API:
101
+
102
+ ```python
103
+ from huggingface_hub import login
104
+
105
+ login() # Paste your access token when prompted
106
+ ```
107
+ This step is required when running code in online environments such as **Google Colab** or remote servers.
108
+
109
+ **Step 4: Request Access to MentalBERT**
110
+
111
+ The model `mental/mental-bert-base-uncased` is a gated repository.
112
+ You must explicitly request access on its Hugging Face model page:
113
+
114
+ - https://huggingface.co/mental/mental-bert-base-uncased
115
+
116
+ Once access is granted, you will be able to download the model using your access token.
117
+
118
+ ```python
119
+ import torch
120
+ from transformers import AutoConfig, AutoTokenizer
121
+ from huggingface_hub import hf_hub_download, login
122
+ # login() # Required when running in an online environment (e.g., Google Colab)
123
+ # from model import BERTDiseaseClassifier
124
+
125
+ repo_id = "shallowblueQAQ/Symptom-model"
126
+ subfolder = "relevance_model"
127
+ # subfolder = "status_model"
128
+
129
+ # 1. Load Config & Tokenizer
130
+ config = AutoConfig.from_pretrained(repo_id, subfolder=subfolder)
131
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder)
132
+
133
+ # 2. Initialize Model Architecture
134
+ # model = BERTDiseaseClassifier(model_type="mental/mental-bert-base-uncased", num_symps=len(config.id2label)) # mental-bert for status model
135
+ # model = BERTDiseaseClassifier(model_type="bert-large-uncased", num_symps=len(config.id2label)) # bert-large for relevance model
136
+ # Replace `/path/to/bert-large-uncased` with the actual local path where BERT-Large is stored.
137
+ model = BERTDiseaseClassifier(model_type="/path/to/bert-large-uncased", num_symps=len(config.id2label))
138
+
139
+ # 3. Load Weights
140
+ weights_path = hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename="pytorch_model.bin")
141
+ model.load_state_dict(torch.load(weights_path, map_location="cpu"))
142
+ model.eval()
143
+
144
+ # 4. Inference
145
+ text = "I had a headache yesterday." if subfolder == "relevance_model" else "Does taking away distractions from some one that has ADD distract the person more or less?"
146
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
147
+
148
+ with torch.no_grad():
149
+ logits = model(**inputs)
150
+ probs = torch.sigmoid(logits)
151
+
152
+ # Display Predictions (Multi-label)
153
+ threshold = 0.5
154
+ for i, prob in enumerate(probs[0]):
155
+ if prob > threshold:
156
+ print(f"Detected: {config.id2label[i]} ({prob:.4f})")
157
+ ```
158
+
159
+ ## Performance
160
+ | Disease | AUC(%) |
161
+ | :--- | :--- |
162
+ | Anger_Irritability | 0.988647 |
163
+ | Anxious_Mood | 0.994557 |
164
+ | Autonomic_Respiratory_Cardiovascular_symptoms | 0.998247 |
165
+ | Decreased_energy_tiredness_fatigue | 0.993981 |
166
+ | Depressed_Mood | 0.983409 |
167
+ | Gastrointestinal_symptoms | 0.999107 |
168
+ | Genitourinary_sexual_symptoms | 0.998653 |
169
+ | Hyperactivity_agitation | 0.988879 |
170
+ | Impulsivity | 0.999138 |
171
+ | Inattention | 0.996561 |
172
+ | Suicidal_ideas | 0.999493 |
173
+ | Worthlessness_and_guilty | 0.989372 |
174
+ | avoidance_of_stimuli | 0.996880 |
175
+ | compensatory_behaviors_to_prevent_weight_gain | 0.999132 |
176
+ | compulsions | 0.993975 |
177
+ | diminished_emotional_expression | 0.996088 |
178
+ | risky_behaviors | 0.996795 |
179
+ | drastical_shift_in_mood_and_energy | 0.998956 |
180
+ | fear_of_gaining_weight | 0.998365 |
181
+ | fears_of_being_negatively_evaluated | 0.998446 |
182
+ | flight_of_ideas | 0.997800 |
183
+ | intrusion_symptoms | 0.998752 |
184
+ | loss_of_interest_or_motivation | 0.996038 |
185
+ | more_talktive | 0.999322 |
186
+ | obsession | 0.994734 |
187
+ | panic_fear | 0.997528 |
188
+ | poor_memory | 0.998939 |
189
+ | sleep_disturbance | 0.998297 |
190
+ | somatic_muscle | 0.998672 |
191
+ | Derealization&dissociation | 0.996076 |
192
+ | somatic_symptoms_sensory | 0.998538 |
193
+ | weight_and_appetite_change | 0.998482 |
194
+
195
+ ## ⚠️ Ethical Considerations & Limitations
196
+ 1. Research Use Only: This model is intended for research purposes only. It is not a diagnostic tool and must not be used for self-diagnosis or clinical decision-making.
197
+
198
+ 2. Bias & Errors: The model is trained on Reddit data and may reflect specific linguistic styles or biases present in that community. It may not generalize perfectly to other platforms or populations.
199
+
200
+ 3. Data Privacy: The training data involves sensitive mental health disclosures. While the model weights do not directly expose user data, outputs should be handled with care to protect user privacy.