MentalBERT V5 - Hierarchical V2 (Cardiff DA + Two-Branch + Longformer)
Five-stage cascade for fine-grained mental-health text classification. Iteration 2 of the V5 hierarchical architecture, retrained with anti-overconfidence regularizers and per-source reliability weighting to address the Suicidal precision issues from V1.
Final Test Results (V5 stratified 70/10/20, random_state=42)
| Metric | Value |
|---|---|
| Accuracy | 83.51% |
| F1 macro | 0.8396 |
| F1 weighted | 0.8325 |
| Sui->Dep (missed crises) | 181 |
| Dep->Sui (false alarms) | 1075 |
| Total Dep<->Sui bleed | 1256 |
| Suicidal precision | 0.7064 |
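
The directional counts in the table can be reproduced from gold and predicted labels. The sketch below is a minimal illustration, not the evaluation script used for this model; the function names and the toy labels are made up for the example. As a consistency check on the reported numbers, 181 + 1075 = 1256.

```python
def dep_sui_bleed(y_true, y_pred):
    """Count Depression/Suicidal confusions in each direction."""
    pairs = list(zip(y_true, y_pred))
    # True Suicidal predicted as Depression: a missed crisis.
    sui_to_dep = sum(t == "Suicidal" and p == "Depression" for t, p in pairs)
    # True Depression predicted as Suicidal: a false alarm.
    dep_to_sui = sum(t == "Depression" and p == "Suicidal" for t, p in pairs)
    return {
        "sui_to_dep": sui_to_dep,
        "dep_to_sui": dep_to_sui,
        "total": sui_to_dep + dep_to_sui,
    }


def precision(y_true, y_pred, label):
    """Fraction of predictions of `label` that are correct."""
    predicted = sum(p == label for p in y_pred)
    correct = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    return correct / predicted if predicted else 0.0


# Toy labels, invented for illustration only.
toy_true = ["Suicidal", "Suicidal", "Depression", "Depression", "Normal"]
toy_pred = ["Suicidal", "Depression", "Suicidal", "Depression", "Normal"]
bleed = dep_sui_bleed(toy_true, toy_pred)
suicidal_precision = precision(toy_true, toy_pred, "Suicidal")
```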
Architecture
text -> Stage 0 (Cardiff RoBERTa, DA gate, full V5)
-> Stage 1A (MentalBERT, Suicidal one-vs-all, full V5 minus DA)
-> Stage 1B (MentalBERT, Normal vs Distress, full V5)
-> Stage 2 (MentalBERT, 5-class distress, full V5)
-> Stage 3 (Longformer, Dep/Sui rescorer, full V5 Dep+Sui)
V2 training improvements (vs V1)
V1 produced inflated Suicidal logits: the Stage 1A threshold had to be calibrated to an extreme 0.75 to balance precision and recall, which indicates systematic overconfidence. V2 addresses this with four targeted changes to Stages 1A and 3 (the safety-critical stages):
- No WeightedRandomSampler. V1 stacked sampler-based 50/50 batch balancing on top of asymmetric class weights, which over-corrected. V2 uses class weights alone.
- SUI_BOOST and DEP_SUI_BOOST: 3.0 -> 2.0. Modest safety bias instead of extreme.
- label_smoothing on Stages 1A and 3: 0.05 -> 0.10. Direct anti-overconfidence regularizer.
- Per-source reliability weighting in the cross-entropy loss. Per the v11 dataset report, label noise concentrates in SWMH (22% Sui error rate) and kaggle/HF (16%). Each training sample's loss is multiplied by its source weight: cssrs=1.00, olid=1.00, kaggle_bpd=0.95, kaggle=0.70, huggingface=0.70, swmh=0.50.
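
To make the last two items concrete, here is a dependency-free sketch of a label-smoothed cross-entropy scaled per sample by source weight. It is an illustration under assumptions, not the training code: the function names are hypothetical, class weights are left out for brevity, and averaging over the batch (rather than normalizing by the sum of weights) is a guess. In a PyTorch loop, a comparable effect could be obtained from `torch.nn.functional.cross_entropy(..., label_smoothing=0.10, reduction="none")` multiplied by a per-sample weight tensor.

```python
import math

# Source weights quoted in the list above.
SOURCE_WEIGHTS = {
    "cssrs": 1.00, "olid": 1.00, "kaggle_bpd": 0.95,
    "kaggle": 0.70, "huggingface": 0.70, "swmh": 0.50,
}


def smoothed_cross_entropy(logits, target, smoothing=0.10):
    """Cross-entropy of one sample against a label-smoothed target."""
    # Numerically stable log-softmax.
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    log_probs = [x - log_z for x in logits]
    k = len(logits)
    # Smoothed target: (1 - smoothing) on the true class plus
    # smoothing / k spread uniformly over all k classes.
    return -sum(
        ((1.0 - smoothing) * (i == target) + smoothing / k) * lp
        for i, lp in enumerate(log_probs)
    )


def source_weighted_loss(batch_logits, batch_targets, batch_sources, smoothing=0.10):
    """Batch mean of per-sample losses, each scaled by its source weight."""
    losses = [
        SOURCE_WEIGHTS[source] * smoothed_cross_entropy(logits, target, smoothing)
        for logits, target, source in zip(batch_logits, batch_targets, batch_sources)
    ]
    return sum(losses) / len(losses)
```

With this weighting, an SWMH sample contributes half the loss (and half the gradient) of an otherwise identical C-SSRS sample, so noisier sources pull the decision boundary less.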
Calibrated thresholds: Stage 0 = 0.7, Stage 1A = 0.5, Stage 3 = 0.3.
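
The routing through the cascade can be sketched as below. This is a reading of the architecture and thresholds described above, not the code in `pipeline.py`. Several details are assumptions: that the stages run strictly in sequence, that each threshold is compared against the positive-class probability, the score keys for Stage 0 and Stage 1B, and that a Stage 3 pass that does not flip the label falls back to the Stage 2 result. Each stage is modelled as a hypothetical callable that maps text to a dict of scores.

```python
# Thresholds quoted above; how they are applied is an assumption.
THRESHOLDS = {"s0": 0.7, "s1a": 0.5, "s3": 0.3}


def _result(label, stage, scores, **extra):
    """Build a response dict shaped like the documented schema."""
    return {"label": label, "stage": stage, "confidence": scores[label],
            "scores": scores, **extra}


def route(text, s0, s1a, s1b, s2, s3):
    """Run the cascade; each stage argument is a callable text -> scores."""
    # Stage 0: Directed Aggression gate.
    da = s0(text)
    if da["Directed Aggression"] >= THRESHOLDS["s0"]:
        return _result("Directed Aggression", "s0", da)

    # Stage 1A: Suicidal one-vs-all.
    sui = s1a(text)
    if sui["Suicidal"] >= THRESHOLDS["s1a"]:
        return _result("Suicidal", "s1a", sui)

    # Stage 1B: Normal vs Distress.
    normal_vs_distress = s1b(text)
    if normal_vs_distress["Normal"] >= normal_vs_distress["Distress"]:
        return _result("Normal", "s1b", normal_vs_distress)

    # Stage 2: five-class distress label.
    distress = s2(text)
    label = max(distress, key=distress.get)

    # Stage 3: the rescorer only looks at Depression calls.
    if label == "Depression":
        rescored = s3(text)
        if rescored["Suicidal"] >= THRESHOLDS["s3"]:
            return _result("Suicidal", "s3", rescored, rescored=True)

    return _result(label, "s2", distress)
```

The low Stage 3 threshold (0.3) means the rescorer flips a Depression call to Suicidal on fairly weak evidence, which matches the safety bias described for the Dep/Sui boundary.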
API Usage - HF Inference Endpoints
After deploying this repo as an Inference Endpoint, your backend calls it like:
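import requests
# Replace both placeholders with your own endpoint URL and access token.
ENDPOINT_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."
headers = {"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"}
response = requests.post(
    ENDPOINT_URL,
    headers=headers,
    json={"inputs": "I have been stockpiling pills, just in case I decide it is time."},
    timeout=30,  # fail fast instead of hanging on a cold or unreachable endpoint
)
response.raise_for_status()  # raise on HTTP errors rather than parsing an error body
print(response.json())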
Response Schema
{
  "label": "Suicidal",
  "stage": "s1a",
  "confidence": 0.87,
  "scores": { "Suicidal": 0.87, "Other": 0.13 }
}
When Stage 3 fires (Stage 2 said Depression but Longformer flips to Suicidal):
{
  "label": "Suicidal",
  "stage": "s3",
  "confidence": 0.74,
  "rescored": true,
  "scores": { "Suicidal": 0.74, "Depression": 0.26 }
}
Batch input - pass a list under inputs, get a list back.
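
A batch call might look like the following sketch. The helper name `classify_batch` is hypothetical, and the assumption that results come back in input order is not stated in this card, so verify it against your deployment before relying on the pairing.

```python
import requests


def classify_batch(texts, endpoint_url, token, timeout=30):
    """Send several texts in one request and pair each with its result."""
    texts = list(texts)
    headers = {"Authorization": f"Bearer {token}",
               "Content-Type": "application/json"}
    response = requests.post(endpoint_url, headers=headers,
                             json={"inputs": texts}, timeout=timeout)
    response.raise_for_status()
    results = response.json()
    # Assumption: one result per input, returned in input order.
    if not isinstance(results, list) or len(results) != len(texts):
        raise ValueError("expected one result per input text")
    return list(zip(texts, results))
```

Checking the length before zipping avoids silently mispairing texts and labels if the endpoint returns an error object or a partial list.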
Field Reference
| Field | Type | Description |
|---|---|---|
| label | string | One of the 8 classes |
| stage | s0 / s1a / s1b / s2 / s3 | Which stage decided |
| confidence | float (0-1) | Confidence in the chosen label |
| scores | dict | Per-class softmax scores at the deciding stage |
| rescored | bool (only when stage=s3) | Stage 3 Longformer overrode Stage 2's Depression call |
Class List
Anxiety, Bipolar, Depression, Directed Aggression, Normal, Personality Disorder, Stress, Suicidal.
Local Use
from huggingface_hub import snapshot_download
from pipeline import HierarchicalMentalHealthPipeline
path = snapshot_download(repo_id="<YOUR_USERNAME>/mentalbert-v5-hierarchical-longformer")
pipe = HierarchicalMentalHealthPipeline(path)
result = pipe("I haven't slept in days, everything feels pointless")
print(result)
Limitations
- Screening signal, not a clinical diagnosis. Deploy with human review for safety-critical use.
- Trained on Reddit-style English text. Performance will degrade on out-of-distribution domains (clinical notes, formal prose, non-English).
- Sui->Dep errors remain (181 on the test set). Pair the model with a conservative downstream policy (e.g. a confidence threshold on s1a and s3 outputs) for crisis-intervention escalation.
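
One possible shape for such a downstream policy is sketched below. It is an assumption-laden illustration, not a recommended clinical protocol: the function name, the tier names, and the cut-off values are all hypothetical and would need tuning on validation data with clinical input. To stay conservative, every Suicidal label reaches a human; the confidence threshold only decides how urgently.

```python
# Hypothetical cut-offs for illustration only; tune on validation data.
URGENT_MIN_CONFIDENCE = {"s1a": 0.80, "s3": 0.60}


def escalation_tier(result, urgent_min_confidence=None):
    """Map one response dict to 'urgent', 'review', or 'none'."""
    if urgent_min_confidence is None:
        urgent_min_confidence = URGENT_MIN_CONFIDENCE
    if result.get("label") != "Suicidal":
        return "none"
    floor = urgent_min_confidence.get(result.get("stage"))
    if floor is not None and result.get("confidence", 0.0) >= floor:
        return "urgent"
    # Low-confidence or unexpected-stage Suicidal calls still get human review.
    return "review"
```

Applied to the two example responses in the Response Schema section, both the s1a call (0.87) and the s3 call (0.74) would be tiered as urgent under these placeholder cut-offs.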
Training Details
- Hardware: NVIDIA T4 (Kaggle)
- Total training time: ~6-7 hours across all 5 stages
- See config.json for full hyperparameters and per-stage validation F1s.