MentalBERT V5 — Source-Aware Multi-Task Classifier
Architecture: Dual-head MentalBERT (BertModel base + classification head + auxiliary source head)
Dataset: V5 (6 sources, 8 classes, ~88k samples)
Test Accuracy: 83.23% | F1 Macro: 0.8381
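The two headline metrics measure different things: accuracy counts correct predictions over all samples, while macro F1 averages the per-class F1 scores with equal weight per class. The sketch below shows one common way to compute both in plain Python. The function names `macro_f1` and `accuracy` and the toy labels are illustrative assumptions, not the evaluation code used for this model.

```python
from collections import Counter

def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every class seen in y_true or y_pred."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class gains a false positive
            fn[t] += 1  # true class gains a false negative
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in labels:
        denom = 2 * tp[c] + fp[c] + fn[c]
        scores.append(2 * tp[c] / denom if denom else 0.0)
    return sum(scores) / len(scores)

def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Toy labels for illustration only (not drawn from the V5 test set).
y_true = ["Normal", "Normal", "Normal", "Stress", "Suicidal"]
y_pred = ["Normal", "Normal", "Stress", "Stress", "Suicidal"]

print(f"accuracy: {accuracy(y_true, y_pred):.4f}")
print(f"macro F1: {macro_f1(y_true, y_pred):.4f}")
```

Because macro F1 gives every class the same weight regardless of how many samples it has, it can sit above or below accuracy depending on how the errors are spread across classes.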
Load Pattern
```python
import json

import joblib
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import BertModel, BertTokenizerFast

REPO_ID = 'itsLu/mentalbert-v5-source-aware'

# 1. Load BertModel base and tokenizer
base = BertModel.from_pretrained(REPO_ID)
tok = BertTokenizerFast.from_pretrained(REPO_ID)

# 2. Load config
config_path = hf_hub_download(REPO_ID, 'inference_config.json')
with open(config_path) as f:
    cfg = json.load(f)

# 3. Reconstruct classification head
cls_head = nn.Linear(768, cfg['n_classes'])
head_path = hf_hub_download(REPO_ID, 'cls_head.pt')
cls_head.load_state_dict(torch.load(head_path, map_location='cpu'))

# 4. Reconstruct wrapper model
class InferenceModel(nn.Module):
    def __init__(self, bert, head):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)  # inactive once .eval() is called
        self.head = head

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.pooler_output
        return self.head(self.dropout(pooled))

model = InferenceModel(base, cls_head).eval()

# 5. Inference
le_path = hf_hub_download(REPO_ID, 'label_encoder.joblib')
le = joblib.load(le_path)

def predict(text):
    """Return (label, probability) for a single input string."""
    enc = tok(text, max_length=128, padding='max_length',
              truncation=True, return_tensors='pt')
    with torch.no_grad():
        logits = model(enc['input_ids'], enc['attention_mask'])
    # Drop the batch dimension (batch size is 1) before converting to NumPy
    probs = torch.softmax(logits, dim=1).squeeze(0).numpy()
    idx = probs.argmax()
    return le.classes_[idx], float(probs[idx])

label, prob = predict("I can't stop thinking about how worthless I am.")
print(label, f'{prob:.2%}')
```
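`predict` returns only the single most probable class. When the runner-up classes matter, the same logits can be decoded into a ranked list. The following is a minimal, dependency-free sketch of that decoding step. It assumes the label encoder stores the eight class names in alphabetical order (the default for scikit-learn's `LabelEncoder`, which this card does not confirm is the encoder type). The helper names `softmax` and `top_k` and the example logits are hypothetical.

```python
import math

# Class names as listed in this card, assumed to be in label-encoder order
# (alphabetical, which is how scikit-learn's LabelEncoder sorts string labels).
CLASSES = [
    "Anxiety", "Bipolar", "Depression", "Directed Aggression",
    "Normal", "Personality Disorder", "Stress", "Suicidal",
]

def softmax(logits):
    """Numerically stable softmax over a flat list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(logits, k=3, classes=CLASSES):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(zip(classes, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Made-up logits for illustration only; real values come from the model.
example_logits = [0.2, -1.0, 2.5, -0.5, 0.1, -0.3, 0.4, 1.9]
for label, prob in top_k(example_logits):
    print(f"{label}: {prob:.2%}")
```

With the real model, the logits row would be `logits.squeeze(0).tolist()` and the class list would be `list(le.classes_)`, which avoids relying on the ordering assumption.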
Classes
- Anxiety
- Bipolar
- Depression
- Directed Aggression
- Normal
- Personality Disorder
- Stress
- Suicidal
Source Reliability Weights
| Source | Reliability |
|---|---|
| cssrs | 1.0 |
| olid | 1.0 |
| kaggle_bpd | 0.95 |
| huggingface | 0.7 |
| kaggle | 0.7 |
| swmh | 0.5 |
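
This card does not state how the reliability weights enter training. One plausible reading, shown here purely as an assumption, is that each sample's loss is scaled by the weight of its source, so that noisier sources contribute less to the gradient. The function name `weighted_nll` and the normalization by total weight are illustrative choices, not the author's confirmed method.

```python
import math

# Reliability weights copied from the table in this card.
SOURCE_RELIABILITY = {
    "cssrs": 1.0,
    "olid": 1.0,
    "kaggle_bpd": 0.95,
    "huggingface": 0.7,
    "kaggle": 0.7,
    "swmh": 0.5,
}

def weighted_nll(true_class_probs, sources, weights=SOURCE_RELIABILITY):
    """Reliability-weighted mean negative log-likelihood (assumed scheme).

    true_class_probs: probability the model gave the correct class, per sample.
    sources: source name per sample, used to look up its reliability weight.
    """
    total_loss = 0.0
    total_weight = 0.0
    for p, src in zip(true_class_probs, sources):
        w = weights[src]
        total_loss += w * -math.log(p)
        total_weight += w
    return total_loss / total_weight

# Toy batch: a confident hit from cssrs and a poor prediction from swmh.
probs = [0.9, 0.2]
srcs = ["cssrs", "swmh"]
print(f"weighted loss: {weighted_nll(probs, srcs):.4f}")
```

Under this scheme the badly predicted `swmh` sample counts half as much as the `cssrs` sample, so the batch loss is lower than the plain unweighted mean would be.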