# MentalBERT V5 — Source-Aware Multi-Task Classifier

- **Architecture:** Dual-head MentalBERT (`BertModel` base + classification head + auxiliary source head)
- **Dataset:** V5 (6 sources, 8 classes, ~88k samples)
- **Test Accuracy:** 83.23%
- **F1 Macro:** 0.8381
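
The load pattern below only rebuilds the classification head, so the training-time dual-head setup is not shown. The following is a minimal sketch of what such a setup could look like. It assumes both heads read the same dropout-regularised pooled output (as the inference wrapper does for the classification head) and that the auxiliary source loss is added to the main loss with a weight `aux_weight`. `DualHead`, `multitask_loss` and `aux_weight=0.3` are hypothetical names and values, not taken from the repository. Random features stand in for the real BERT `pooler_output`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DualHead(nn.Module):
    """Hypothetical training-time heads on top of a pooled BERT representation."""

    def __init__(self, hidden_size=768, n_classes=8, n_sources=6, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.cls_head = nn.Linear(hidden_size, n_classes)  # main task: 8 mental-health labels
        self.src_head = nn.Linear(hidden_size, n_sources)  # auxiliary task: which of the 6 sources

    def forward(self, pooled):
        h = self.dropout(pooled)
        return self.cls_head(h), self.src_head(h)


def multitask_loss(cls_logits, src_logits, labels, sources, aux_weight=0.3):
    # aux_weight is a hypothetical value; the card does not state the real one
    main = F.cross_entropy(cls_logits, labels)
    aux = F.cross_entropy(src_logits, sources)
    return main + aux_weight * aux


# Toy check: random "pooled" features instead of real BERT output
torch.manual_seed(0)
heads = DualHead(hidden_size=16)
pooled = torch.randn(4, 16)
cls_logits, src_logits = heads(pooled)
loss = multitask_loss(cls_logits, src_logits,
                      labels=torch.tensor([0, 3, 7, 2]),
                      sources=torch.tensor([0, 1, 5, 2]))
loss.backward()
print(cls_logits.shape, src_logits.shape, loss.item())
```

At inference only the classification head matters, which is why the load pattern reconstructs a single `nn.Linear`.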

## Load Pattern

The repository stores the BERT encoder and tokenizer in standard `transformers` format, plus three side files: `inference_config.json` (provides `n_classes`), `cls_head.pt` (classification-head weights) and `label_encoder.joblib` (index-to-label mapping). Only the classification head is rebuilt for inference; the auxiliary source head is not loaded.

```python
import json

import joblib
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import BertModel, BertTokenizerFast

REPO_ID = 'itsLu/mentalbert-v5-source-aware'

# 1. Load the BertModel base and its tokenizer
base = BertModel.from_pretrained(REPO_ID)
tok = BertTokenizerFast.from_pretrained(REPO_ID)

# 2. Load the inference config (provides n_classes)
config_path = hf_hub_download(REPO_ID, 'inference_config.json')
with open(config_path) as f:
    cfg = json.load(f)

# 3. Rebuild the classification head and load its weights
#    (hidden_size is 768 for a BERT-base encoder)
cls_head = nn.Linear(base.config.hidden_size, cfg['n_classes'])
head_path = hf_hub_download(REPO_ID, 'cls_head.pt')
cls_head.load_state_dict(torch.load(head_path, map_location='cpu'))

# 4. Wrap encoder + head into one inference model
class InferenceModel(nn.Module):
    def __init__(self, bert, head):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)  # no-op once the model is in eval mode
        self.head = head

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.pooler_output  # [CLS] hidden state after BERT's dense + tanh pooler
        return self.head(self.dropout(pooled))

model = InferenceModel(base, cls_head).eval()

# 5. Load the label encoder and run inference
le_path = hf_hub_download(REPO_ID, 'label_encoder.joblib')
le = joblib.load(le_path)

def predict(text):
    enc = tok(text, max_length=128, padding='max_length',
              truncation=True, return_tensors='pt')
    with torch.no_grad():
        logits = model(enc['input_ids'], enc['attention_mask'])
    probs = torch.softmax(logits, dim=-1)[0].numpy()  # single input -> 1-D array
    idx = int(probs.argmax())
    return le.classes_[idx], float(probs[idx])

label, prob = predict("I can't stop thinking about how worthless I am.")
print(label, f'{prob:.2%}')
```

## Classes
- Anxiety
- Bipolar
- Depression
- Directed Aggression
- Normal
- Personality Disorder
- Stress
- Suicidal
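
`predict` returns only the top label. To see the runner-up classes as well, a row of logits can be decoded into a ranked top-k list. This sketch assumes that `le.classes_` follows the alphabetical order listed above (scikit-learn's `LabelEncoder` stores its classes sorted); `CLASSES`, `softmax` and `decode` are hypothetical helpers, so confirm `list(le.classes_) == CLASSES` before relying on the hard-coded list.

```python
import math

# Assumption: same order as le.classes_ (alphabetical); verify before use
CLASSES = [
    'Anxiety', 'Bipolar', 'Depression', 'Directed Aggression',
    'Normal', 'Personality Disorder', 'Stress', 'Suicidal',
]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits, k=3):
    """Return the top-k (label, probability) pairs for one row of logits."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(CLASSES[i], probs[i]) for i in ranked[:k]]


# Made-up logits for illustration; top-3 labels: Depression, Suicidal, Stress
print(decode([0.1, -1.2, 2.5, -0.3, 0.4, -2.0, 1.1, 1.9]))
```

With the model from the load pattern, the logits computed inside `predict` can be passed as `logits[0].tolist()`.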

## Source Reliability Weights
| Source | Reliability |
|--------|-------------|
| cssrs | 1.0 |
| olid | 1.0 |
| kaggle_bpd | 0.95 |
| huggingface | 0.7 |
| kaggle | 0.7 |
| swmh | 0.5 |
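
The card lists these weights but does not say how they enter training. One common option, shown here purely as an assumption, is to scale each sample's cross-entropy loss by the reliability of its source and normalise by the total weight, so that examples from noisier sources such as `swmh` pull less on the gradients. `SOURCE_WEIGHTS`, `cross_entropy` and `weighted_loss` are hypothetical names, and the batch uses toy two-class logits for brevity.

```python
import math

# Reliability weights copied from the table above
SOURCE_WEIGHTS = {
    'cssrs': 1.0, 'olid': 1.0, 'kaggle_bpd': 0.95,
    'huggingface': 0.7, 'kaggle': 0.7, 'swmh': 0.5,
}


def cross_entropy(logits, target):
    """Negative log-softmax of the target class, via a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def weighted_loss(batch):
    """Hypothetical scheme: source-weighted mean of per-sample cross-entropy.

    batch is a list of (logits, target_index, source_name) tuples.
    """
    num = sum(SOURCE_WEIGHTS[src] * cross_entropy(lg, y) for lg, y, src in batch)
    den = sum(SOURCE_WEIGHTS[src] for _, _, src in batch)
    return num / den


batch = [
    ([2.0, 0.0], 0, 'cssrs'),  # confident and correct, most trusted source
    ([0.0, 2.0], 0, 'swmh'),   # confident but wrong, least trusted source
]
print(weighted_loss(batch))
```

Because the misclassified `swmh` example counts half as much as the `cssrs` one, the weighted loss is lower than the plain mean over the batch.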