
MentalBERT V5 — Source-Aware Multi-Task Classifier

Architecture: Dual-head MentalBERT (BertModel base + classification head + auxiliary source head)
Dataset: V5 (6 sources, 8 classes, ~88k samples)
Test Accuracy: 83.23% | F1 Macro: 0.8381
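
The architecture line above describes one shared encoder feeding two linear heads. A minimal sketch of that layout follows. It is an illustration under stated assumptions, not the training code: `TinyEncoder` is a hypothetical stand-in for `BertModel` so the snippet runs without downloading weights, the name `src_head` is invented here, the hidden size (768) and dropout (0.1) are copied from the load pattern in this card, and the head widths (8 classes, 6 sources) come from the dataset line.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for BertModel: one pooled vector per input."""

    def __init__(self, vocab_size=100, hidden=768):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask):
        # Mean-pool token embeddings over non-padding positions.
        mask = attention_mask.unsqueeze(-1).float()
        summed = (self.embed(input_ids) * mask).sum(dim=1)
        return summed / mask.sum(dim=1).clamp(min=1.0)


class DualHeadClassifier(nn.Module):
    """Shared encoder feeding a class head and an auxiliary source head."""

    def __init__(self, encoder, hidden=768, n_classes=8, n_sources=6):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(0.1)
        self.cls_head = nn.Linear(hidden, n_classes)   # mental-health class
        self.src_head = nn.Linear(hidden, n_sources)   # which dataset the text came from

    def forward(self, input_ids, attention_mask):
        pooled = self.dropout(self.encoder(input_ids, attention_mask))
        return self.cls_head(pooled), self.src_head(pooled)


demo = DualHeadClassifier(TinyEncoder()).eval()
ids = torch.randint(0, 100, (2, 16))      # batch of 2 sequences, 16 tokens each
mask = torch.ones_like(ids)
with torch.no_grad():
    class_logits, source_logits = demo(ids, mask)
print(class_logits.shape, source_logits.shape)
```

Only the classification head is needed at inference time, which is consistent with the load pattern reconstructing `cls_head` alone.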

Load Pattern

import torch
import torch.nn as nn
import joblib, json
from transformers import BertModel, BertTokenizerFast
from huggingface_hub import hf_hub_download

# 1. Load BertModel base and tokenizer
base = BertModel.from_pretrained('itsLu/mentalbert-v5-source-aware')
tok  = BertTokenizerFast.from_pretrained('itsLu/mentalbert-v5-source-aware')

# 2. Load config
config_path = hf_hub_download('itsLu/mentalbert-v5-source-aware', 'inference_config.json')
with open(config_path) as f:
    cfg = json.load(f)

# 3. Reconstruct classification head
cls_head = nn.Linear(768, cfg['n_classes'])
head_path = hf_hub_download('itsLu/mentalbert-v5-source-aware', 'cls_head.pt')
cls_head.load_state_dict(torch.load(head_path, map_location='cpu'))

# 4. Reconstruct wrapper model
class InferenceModel(nn.Module):
    def __init__(self, bert, head):
        super().__init__()
        self.bert    = bert
        self.dropout = nn.Dropout(0.1)
        self.head    = head
    def forward(self, input_ids, attention_mask):
        out    = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.pooler_output
        return self.head(self.dropout(pooled))

model = InferenceModel(base, cls_head).eval()

# 5. Inference
le_path = hf_hub_download('itsLu/mentalbert-v5-source-aware', 'label_encoder.joblib')
le = joblib.load(le_path)

def predict(text):
    enc   = tok(text, max_length=128, padding='max_length',
                truncation=True, return_tensors='pt')
    with torch.no_grad():
        logits = model(enc['input_ids'], enc['attention_mask'])
    probs = torch.softmax(logits, dim=1).squeeze().numpy()
    idx   = probs.argmax()
    return le.classes_[idx], float(probs[idx])

label, prob = predict("I can't stop thinking about how worthless I am.")
print(label, f'{prob:.2%}')

Classes

  • Anxiety
  • Bipolar
  • Depression
  • Directed Aggression
  • Normal
  • Personality Disorder
  • Stress
  • Suicidal
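
The `predict` helper in the load pattern returns only the single most likely class. For borderline inputs a ranked list can be more informative. The sketch below shows one way to rank all eight classes from a logit vector using only the standard library. The function names and the example logits are invented for illustration. The class order assumes that `label_encoder.joblib` holds a scikit-learn `LabelEncoder`, which sorts `classes_` alphabetically, and that the stored label strings match the names listed above.

```python
import math

# Class names in the order listed above (alphabetical, which is also how
# scikit-learn's LabelEncoder orders classes_). Assumed, not read from the repo.
CLASSES = [
    "Anxiety", "Bipolar", "Depression", "Directed Aggression",
    "Normal", "Personality Disorder", "Stress", "Suicidal",
]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, classes=CLASSES, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(zip(classes, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration only; these are not model outputs.
example_logits = [0.2, -1.0, 2.5, -0.5, 0.1, -0.8, 0.9, 2.1]
for label, prob in top_k(example_logits):
    print(f"{label}: {prob:.2%}")
```

With the real model, the list passed to `top_k` would be `logits.squeeze(0).tolist()` and `classes` would be `list(le.classes_)`.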

Source Reliability Weights

Source        Reliability
cssrs         1.0
olid          1.0
kaggle_bpd    0.95
huggingface   0.7
kaggle        0.7
swmh          0.5
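
This card does not say how the reliability weights enter training. One common option, shown here purely as an assumption, is to scale each sample's cross-entropy loss by the weight of its source so that less reliable sources pull less on the gradient. The function names and the toy batch are hypothetical; only the weight values come from the table above.

```python
import math

# Reliability weights copied from the table above.
SOURCE_WEIGHTS = {
    "cssrs": 1.0, "olid": 1.0, "kaggle_bpd": 0.95,
    "huggingface": 0.7, "kaggle": 0.7, "swmh": 0.5,
}


def cross_entropy(probs, target_idx):
    """Negative log-likelihood of the true class for one sample."""
    return -math.log(probs[target_idx])


def weighted_batch_loss(batch):
    """Reliability-weighted mean of per-sample losses.

    batch: list of (probs, target_idx, source) tuples.
    """
    weights = [SOURCE_WEIGHTS[source] for _, _, source in batch]
    losses = [cross_entropy(probs, target) for probs, target, _ in batch]
    return sum(w * l for w, l in zip(weights, losses)) / sum(weights)


# Toy batch: a confident correct prediction from a high-reliability source
# and a wrong prediction from the lowest-reliability source.
batch = [
    ([0.9, 0.1], 0, "cssrs"),  # weight 1.0
    ([0.2, 0.8], 0, "swmh"),   # weight 0.5
]
weighted = weighted_batch_loss(batch)
unweighted = sum(cross_entropy(p, t) for p, t, _ in batch) / len(batch)
print(f"weighted={weighted:.4f} unweighted={unweighted:.4f}")
```

In this toy batch the weighted loss is lower than the plain mean because the misclassified sample comes from `swmh`, whose weight is 0.5.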