MentalBERT V5: Hierarchical V2 (Cardiff DA + Two-Branch + Longformer)

Five-stage cascade for fine-grained mental-health text classification. Iteration 2 of the V5 hierarchical architecture, retrained with anti-overconfidence regularizers and per-source reliability weighting to address the Suicidal precision issues from V1.

Final Test Results (V5 stratified 70/10/20 train/val/test split, random_state=42)

Metric                       Value
Accuracy                     83.51%
F1 macro                     0.8396
F1 weighted                  0.8325
Sui->Dep (missed crises)     181
Dep->Sui (false alarms)      1075
Total Dep<->Sui bleed        1256
Suicidal precision           0.7064
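The two Dep<->Sui rows are off-diagonal cells of the test confusion matrix, and the total bleed is their sum (181 + 1075 = 1256). Below is a minimal pure-Python sketch of how these counts and per-class precision can be computed from gold and predicted label lists. The function names are illustrative and are not part of this repo.

```python
from collections import Counter


def dep_sui_bleed(y_true, y_pred):
    """Count Depression/Suicidal confusions between gold and predicted labels."""
    pairs = Counter(zip(y_true, y_pred))
    sui_to_dep = pairs[("Suicidal", "Depression")]  # missed crises
    dep_to_sui = pairs[("Depression", "Suicidal")]  # false alarms
    return {"sui_to_dep": sui_to_dep, "dep_to_sui": dep_to_sui,
            "total": sui_to_dep + dep_to_sui}


def precision(y_true, y_pred, label):
    """Fraction of predictions of `label` that are correct (0.0 if never predicted)."""
    gold_for_predicted = [t for t, p in zip(y_true, y_pred) if p == label]
    if not gold_for_predicted:
        return 0.0
    return sum(t == label for t in gold_for_predicted) / len(gold_for_predicted)


y_true = ["Suicidal", "Depression", "Depression", "Suicidal", "Normal"]
y_pred = ["Depression", "Suicidal", "Depression", "Suicidal", "Normal"]
print(dep_sui_bleed(y_true, y_pred))         # {'sui_to_dep': 1, 'dep_to_sui': 1, 'total': 2}
print(precision(y_true, y_pred, "Suicidal"))  # 0.5
```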

Architecture

text -> Stage 0 (Cardiff RoBERTa, DA gate, full V5)
        -> Stage 1A (MentalBERT, Suicidal one-vs-all, full V5 minus DA)
              -> Stage 1B (MentalBERT, Normal vs Distress, full V5)
                    -> Stage 2 (MentalBERT, 5-class distress, full V5)
                          -> Stage 3 (Longformer, Dep/Sui rescorer, full V5 Dep+Sui)
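The cascade passes each text down the stages until one of them commits to a label. Below is a minimal sketch of that control flow. It assumes each stage is a callable that returns class probabilities, and it applies the calibrated thresholds listed in this card as follows: Stage 0 emits Directed Aggression when P(DA) >= 0.7, Stage 1A emits Suicidal when P(Suicidal) >= 0.5, and Stage 3 flips a Depression call to Suicidal when P(Suicidal) >= 0.3. Several details are hypothetical and may differ from the shipped pipeline.py: the stage interface, the plain argmax used for Stage 1B (no threshold is listed for it), and reporting "s2" when Stage 3 does not flip.

```python
from typing import Callable, Dict

Probs = Dict[str, float]
Stage = Callable[[str], Probs]  # hypothetical interface: text -> class probabilities

DISTRESS = ["Anxiety", "Bipolar", "Depression", "Personality Disorder", "Stress"]


def _decision(label: str, stage: str, scores: Probs, **extra) -> dict:
    """Build a response dict shaped like the endpoint examples in this card."""
    return {"label": label, "stage": stage, "confidence": scores[label],
            "scores": scores, **extra}


def cascade(text: str, s0: Stage, s1a: Stage, s1b: Stage, s2: Stage, s3: Stage,
            t0: float = 0.7, t1a: float = 0.5, t3: float = 0.3) -> dict:
    p0 = s0(text)
    if p0["Directed Aggression"] >= t0:         # Stage 0: DA gate
        return _decision("Directed Aggression", "s0", p0)
    p1a = s1a(text)
    if p1a["Suicidal"] >= t1a:                  # Stage 1A: Suicidal one-vs-all
        return _decision("Suicidal", "s1a", p1a)
    p1b = s1b(text)
    if p1b["Normal"] >= p1b["Distress"]:        # Stage 1B: assumed plain argmax
        return _decision("Normal", "s1b", p1b)
    p2 = s2(text)
    label = max(DISTRESS, key=p2.get)           # Stage 2: 5-class distress
    if label == "Depression":
        p3 = s3(text)                           # Stage 3: Longformer Dep/Sui rescorer
        if p3["Suicidal"] >= t3:
            return _decision("Suicidal", "s3", p3, rescored=True)
    return _decision(label, "s2", p2)


# Stub stages standing in for the real models.
stages = dict(
    s0=lambda t: {"Directed Aggression": 0.05, "Other": 0.95},
    s1a=lambda t: {"Suicidal": 0.40, "Other": 0.60},
    s1b=lambda t: {"Normal": 0.10, "Distress": 0.90},
    s2=lambda t: {"Anxiety": 0.10, "Bipolar": 0.05, "Depression": 0.70,
                  "Personality Disorder": 0.05, "Stress": 0.10},
    s3=lambda t: {"Suicidal": 0.35, "Depression": 0.65},
)
print(cascade("example text", **stages))
```

With these stubs, the low Stage 3 cutoff flips the label to Suicidal even though the rescorer's own Suicidal score is below 0.5. That behavior is consistent with the safety-first bias described for the safety-critical stages.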

V2 training improvements (vs V1)

V1 had inflated Suicidal logits: the Stage 1A threshold had to be calibrated to an extreme 0.75 to balance precision and recall, which points to systematic overconfidence. V2 addresses this with four targeted changes to Stages 1A and 3 (the safety-critical stages):

  1. No WeightedRandomSampler. V1 stacked sampler-based 50/50 batch balancing on top of asymmetric class weights, which over-corrected. V2 uses class weights alone.
  2. SUI_BOOST and DEP_SUI_BOOST reduced from 3.0 to 2.0. This keeps a modest safety bias instead of an extreme one.
  3. label_smoothing on Stages 1A and 3 raised from 0.05 to 0.10, a direct anti-overconfidence regularizer.
  4. Per-source reliability weighting in the cross-entropy loss. Per the v11 dataset report, label noise concentrates in SWMH (22% Sui error rate) and kaggle/HF (16%). Each training sample's loss is multiplied by its source weight: cssrs=1.00, olid=1.00, kaggle_bpd=0.95, kaggle=0.70, huggingface=0.70, swmh=0.50.
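Changes 3 and 4 combine into a single per-sample loss. Below is a minimal pure-Python sketch. It assumes the source weight and the class weight both multiply a label-smoothed cross-entropy term and that the batch loss is a plain mean. The class_weights argument is a hypothetical stand-in for the SUI_BOOST/DEP_SUI_BOOST class weighting, because this card does not give its exact form. In PyTorch the same idea is typically written as torch.nn.functional.cross_entropy(logits, targets, weight=..., label_smoothing=0.10, reduction="none") multiplied by a per-sample weight tensor. Note, however, that PyTorch applies class weights inside the smoothing term per class, which differs slightly from this sketch.

```python
import math

# Per-source reliability weights, as listed above.
SOURCE_WEIGHTS = {"cssrs": 1.00, "olid": 1.00, "kaggle_bpd": 0.95,
                  "kaggle": 0.70, "huggingface": 0.70, "swmh": 0.50}


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def reliability_weighted_ce(batch_logits, targets, sources, class_weights, smoothing=0.10):
    """Batch mean of source_weight * class_weight[y] * label-smoothed cross-entropy."""
    total = 0.0
    for logits, y, src in zip(batch_logits, targets, sources):
        n_classes = len(logits)
        logp = log_softmax(logits)
        # Smoothed target: (1 - smoothing) + smoothing/C on the gold class, smoothing/C elsewhere.
        ce = -sum(((1.0 - smoothing) * (c == y) + smoothing / n_classes) * logp[c]
                  for c in range(n_classes))
        total += SOURCE_WEIGHTS[src] * class_weights[y] * ce
    return total / len(targets)


loss = reliability_weighted_ce([[2.0, -1.0], [0.5, 0.5]], targets=[0, 1],
                               sources=["cssrs", "swmh"], class_weights=[1.0, 2.0])
print(loss)
```

Down-weighting a noisy source shrinks its gradient contribution without dropping its examples from training.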

Calibrated thresholds: Stage 0 = 0.7, Stage 1A = 0.5, Stage 3 = 0.3.
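The card does not say how these thresholds were chosen. One common approach is to sweep candidate cutoffs for a binary stage on the validation split and keep the cutoff that maximizes F1 for the positive class. Below is a hedged sketch of that approach. The function names and the default grid are illustrative. A safety-oriented variant could instead pick the highest cutoff that keeps recall above a target.

```python
def f1_at(scores, labels, threshold):
    """Positive-class F1 when predicting positive for score >= threshold."""
    tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y)
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and not y)
    fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


def calibrate_threshold(scores, labels, grid=None):
    """Return the grid value with the best F1 (ties go to the lowest cutoff)."""
    grid = grid or [round(0.05 * i, 2) for i in range(1, 20)]
    return max(grid, key=lambda t: f1_at(scores, labels, t))


val_scores = [0.9, 0.8, 0.6, 0.4, 0.2]  # e.g. P(Suicidal) from Stage 1A
val_labels = [1, 1, 0, 1, 0]
print(calibrate_threshold(val_scores, val_labels))  # 0.25
```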

API Usage - HF Inference Endpoints

After deploying this repo as an Inference Endpoint, your backend can call it like this:

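import requests

ENDPOINT_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."  # your Hugging Face access token
headers = {"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"}

response = requests.post(
    ENDPOINT_URL,
    headers=headers,
    json={"inputs": "I have been stockpiling pills, just in case I decide it is time."},
    timeout=30,  # avoid hanging forever if the endpoint is cold or unreachable
)
response.raise_for_status()  # raise on HTTP errors such as 401 or 503
print(response.json())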

Response Schema

{
  "label": "Suicidal",
  "stage": "s1a",
  "confidence": 0.87,
  "scores": { "Suicidal": 0.87, "Other": 0.13 }
}

When Stage 3 fires (Stage 2 predicted Depression, but the Longformer rescorer flips the label to Suicidal):

{
  "label": "Suicidal",
  "stage": "s3",
  "confidence": 0.74,
  "rescored": true,
  "scores": { "Suicidal": 0.74, "Depression": 0.26 }
}

Batch input: pass a list of strings under "inputs" to get a list of results back.
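Below is a hedged sketch of batched calls that uses only the standard library (urllib instead of requests). It assumes the endpoint accepts a JSON list under "inputs" and returns one result per text in the same order, as described above. Splitting large lists into chunks is an illustrative choice to keep request bodies small, and the chunk size is arbitrary.

```python
import json
import urllib.request


def chunked(items, size):
    """Yield consecutive slices of at most `size` items."""
    for i in range(0, len(items), size):
        yield items[i:i + size]


def classify_batch(texts, endpoint_url, token, chunk_size=32, timeout=60):
    """POST texts in chunks and concatenate the per-text results in order."""
    results = []
    for chunk in chunked(texts, chunk_size):
        req = urllib.request.Request(
            endpoint_url,
            data=json.dumps({"inputs": chunk}).encode("utf-8"),
            headers={"Authorization": f"Bearer {token}",
                     "Content-Type": "application/json"},
            method="POST",
        )
        # urlopen raises urllib.error.HTTPError on non-2xx responses.
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            results.extend(json.loads(resp.read().decode("utf-8")))
    return results
```

For example, labels = [r["label"] for r in classify_batch(posts, ENDPOINT_URL, HF_TOKEN)], where posts is your list of texts.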

Field Reference

Field        Type                          Description
label        string                        One of the 8 classes
stage        string (s0/s1a/s1b/s2/s3)     Which stage made the final decision
confidence   float (0-1)                   Confidence in the chosen label
scores       dict                          Per-class softmax scores at the deciding stage
rescored     bool (only when stage=s3)     Stage 3 Longformer overrode Stage 2's Depression call
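On the client side, it can help to validate each response against this schema before acting on it. Below is a minimal sketch that uses a dataclass. The Prediction type and the parse_prediction helper are hypothetical, and treating a missing rescored field as False is an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict

CLASSES = {"Anxiety", "Bipolar", "Depression", "Directed Aggression",
           "Normal", "Personality Disorder", "Stress", "Suicidal"}
STAGES = {"s0", "s1a", "s1b", "s2", "s3"}


@dataclass
class Prediction:
    label: str
    stage: str
    confidence: float
    scores: Dict[str, float] = field(default_factory=dict)
    rescored: bool = False  # assumption: field is absent unless stage == "s3"


def parse_prediction(d: dict) -> Prediction:
    """Validate one endpoint result and convert it to a Prediction."""
    if d["label"] not in CLASSES:
        raise ValueError(f"unknown label: {d['label']!r}")
    if d["stage"] not in STAGES:
        raise ValueError(f"unknown stage: {d['stage']!r}")
    conf = float(d["confidence"])
    if not 0.0 <= conf <= 1.0:
        raise ValueError(f"confidence out of range: {conf}")
    return Prediction(d["label"], d["stage"], conf,
                      dict(d.get("scores", {})), bool(d.get("rescored", False)))
```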

Class List

Anxiety, Bipolar, Depression, Directed Aggression, Normal, Personality Disorder, Stress, Suicidal.

Local Use

import sys
from huggingface_hub import snapshot_download

path = snapshot_download(repo_id="<YOUR_USERNAME>/mentalbert-v5-hierarchical-longformer")
sys.path.insert(0, path)  # pipeline.py ships inside the downloaded repo, not on the default import path
from pipeline import HierarchicalMentalHealthPipeline

pipe = HierarchicalMentalHealthPipeline(path)
result = pipe("I haven't slept in days, everything feels pointless")
print(result)

Limitations

  • Screening signal, not a clinical diagnosis. Deploy with human review for safety-critical use.
  • Reddit-style English text only. Expect degraded performance on OOD domains (clinical notes, formal prose, non-English).
  • Sui->Dep errors remain. For crisis-intervention escalation, pair the model with a conservative downstream policy (e.g. a confidence threshold on s1a and s3 outputs).
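One way to act on the last point is a routing rule that treats every Suicidal output as reviewable and never auto-clears Depression, since missed crises (Sui->Dep) land there. Below is a minimal sketch that assumes the response fields described above. The queue names and the 0.8 urgency cutoff are illustrative placeholders, not values calibrated for this model.

```python
def escalation_level(result: dict, urgent_at: float = 0.8) -> str:
    """Map one endpoint result to a hypothetical human-review queue."""
    label, stage, conf = result["label"], result["stage"], result["confidence"]
    if label == "Suicidal" and stage in ("s1a", "s3"):
        # Confidence only sets priority; every Suicidal call goes to a human.
        return "urgent" if conf >= urgent_at else "review"
    if label == "Depression":
        # Sui->Dep misses end up here, so Depression is never auto-cleared.
        return "review"
    return "none"


print(escalation_level({"label": "Suicidal", "stage": "s1a", "confidence": 0.87}))  # urgent
```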

Training Details

  • Hardware: NVIDIA T4 (Kaggle)
  • Total training time: ~6-7 hours across all 5 stages
  • See config.json for full hyperparameters and per-stage validation F1s.