MentalBERT V5: Source-Aware Multi-Task Classifier

Architecture: Dual-head MentalBERT (BertModel base + classification head + auxiliary source head)
Dataset: V5 (6 sources, 8 classes, ~88k samples)
Test Accuracy: 83.23% | F1 Macro: 0.8381
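
The card describes a dual-head setup but does not say how the two heads are trained together. A common multi-task recipe, assumed here and not confirmed by the card, sums the cross-entropy of the 8-way class head with a down-weighted cross-entropy of the 6-way source head. The sketch below shows that objective on plain logit lists. `multitask_loss` and its `aux_weight` value are hypothetical.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target index (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def multitask_loss(cls_logits, cls_target, src_logits, src_target, aux_weight=0.3):
    """Main class loss plus a down-weighted auxiliary source loss.

    ASSUMPTION: the additive combination and aux_weight are illustrative;
    the card does not state how the two heads' losses are combined.
    """
    main = cross_entropy(cls_logits, cls_target)  # 8-way mental-health label
    aux = cross_entropy(src_logits, src_target)   # 6-way data-source label
    return main + aux_weight * aux
```

At inference time only the class head matters; the load pattern below reconstructs `cls_head` alone.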

Load Pattern

```python
import json

import joblib
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import BertModel, BertTokenizerFast

REPO_ID = 'itsLu/mentalbert-v5-source-aware'

# 1. Load the BertModel base and tokenizer
base = BertModel.from_pretrained(REPO_ID)
tok  = BertTokenizerFast.from_pretrained(REPO_ID)

# 2. Load the inference config (provides n_classes)
config_path = hf_hub_download(REPO_ID, 'inference_config.json')
with open(config_path) as f:
    cfg = json.load(f)

# 3. Reconstruct the classification head
#    (hidden_size is 768 for this BERT-base encoder)
cls_head = nn.Linear(base.config.hidden_size, cfg['n_classes'])
head_path = hf_hub_download(REPO_ID, 'cls_head.pt')
cls_head.load_state_dict(torch.load(head_path, map_location='cpu'))

# 4. Reconstruct the wrapper model
#    (only the classification head is used at inference;
#    the auxiliary source head is not loaded)
class InferenceModel(nn.Module):
    def __init__(self, bert, head):
        super().__init__()
        self.bert    = bert
        self.dropout = nn.Dropout(0.1)  # no-op once the model is in eval mode
        self.head    = head

    def forward(self, input_ids, attention_mask):
        out    = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.pooler_output      # [CLS] vector passed through BERT's pooler
        return self.head(self.dropout(pooled))

model = InferenceModel(base, cls_head).eval()

# 5. Load the label encoder and run inference
le_path = hf_hub_download(REPO_ID, 'label_encoder.joblib')
le = joblib.load(le_path)

def predict(text):
    enc = tok(text, max_length=128, padding='max_length',
              truncation=True, return_tensors='pt')
    with torch.no_grad():
        logits = model(enc['input_ids'], enc['attention_mask'])
    probs = torch.softmax(logits, dim=-1)[0].numpy()  # single input -> shape (n_classes,)
    idx   = int(probs.argmax())
    return le.classes_[idx], float(probs[idx])

label, prob = predict("I can't stop thinking about how worthless I am.")
print(label, f'{prob:.2%}')
```
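
`predict` returns only the top-1 label. To inspect the runner-up classes as well, the logits can be ranked after a softmax. A minimal stdlib sketch follows; `top_k` is a hypothetical helper, not part of the repository.

```python
import math

def top_k(logits, labels, k=3):
    """Return the k most probable (label, probability) pairs, highest first.

    Hypothetical helper: `logits` is a plain list of floats, `labels` the
    class names in the same order (e.g. list(le.classes_)).
    """
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # shift by the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

Inside `predict`, `top_k(logits[0].tolist(), list(le.classes_))` would return the three most probable labels with their probabilities.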

Classes

  • Anxiety
  • Bipolar
  • Depression
  • Directed Aggression
  • Normal
  • Personality Disorder
  • Stress
  • Suicidal
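
scikit-learn's `LabelEncoder` stores `classes_` in sorted order, and the list above is already alphabetical, so list position should match the encoded index. Assuming the encoder was fit on exactly these strings (worth checking against `le.classes_`), an `id2label` mapping can be built without loading the joblib file:

```python
CLASSES = [
    "Anxiety", "Bipolar", "Depression", "Directed Aggression",
    "Normal", "Personality Disorder", "Stress", "Suicidal",
]

# LabelEncoder assigns indices in sorted order; ASSUMPTION: the encoder was
# fit on exactly these strings, so sorted position == encoded index.
id2label = dict(enumerate(sorted(CLASSES)))
label2id = {label: idx for idx, label in id2label.items()}
```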

Source Reliability Weights

| Source | Reliability weight |
|---|---|
| cssrs | 1.0 |
| olid | 1.0 |
| kaggle_bpd | 0.95 |
| huggingface | 0.7 |
| kaggle | 0.7 |
| swmh | 0.5 |
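
The card lists these weights but does not document how they enter training. One plausible use, assumed here, is scaling each sample's loss by its source's weight so that noisier sources (e.g. `swmh` at 0.5) contribute less. The sketch normalises by the total weight so the result stays on the scale of an unweighted mean; `reliability_weighted_loss` is a hypothetical name.

```python
import math

SOURCE_WEIGHTS = {
    "cssrs": 1.0, "olid": 1.0, "kaggle_bpd": 0.95,
    "huggingface": 0.7, "kaggle": 0.7, "swmh": 0.5,
}

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target index (numerically stable)."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]

def reliability_weighted_loss(batch):
    """batch: list of (logits, target_index, source_name) tuples.

    ASSUMPTION: weights scale per-sample loss; the card does not say how
    they are actually applied during training.
    """
    weighted = 0.0
    total_w = 0.0
    for logits, target, source in batch:
        w = SOURCE_WEIGHTS[source]
        weighted += w * cross_entropy(logits, target)
        total_w += w
    return weighted / total_w  # weighted mean over the batch
```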

Model size: 0.1B params (Safetensors, F32 tensors)