| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModel, AutoConfig |
| from huggingface_hub import hf_hub_download |
| import gradio as gr |
|
|
# Encoder backbone checkpoint and the fine-tuned weights repository on the Hub.
REPO_ID = "mental/mental-roberta-base"
HF_REPO_ID = "vimdhayak/neda-learned-negation"

# Label space for the 5-way distress classification task.
ID2LABEL = {0: 'Anxious', 1: 'Depressed', 2: 'Frustrated', 3: 'Others', 4: 'Suicidal'}
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}
NUM_LABELS = len(ID2LABEL)
MAX_LENGTH = 256  # tokenizer truncation / padding length

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| |
class NegationAwareContextualModulation(nn.Module):
    """Inject token-level negation signals into a hidden-state sequence.

    A small MLP scores each token for negation; the score softly blends two
    learned polarity embeddings into the states, and a context-attention gate
    controls how strongly a learned transform rewrites each (likely negated)
    token. Also emits an auxiliary negation logit from the [CLS] position.
    """

    def __init__(self, H, cw=3, dp=0.1):
        # NOTE: submodule registration order matters — it fixes both the
        # checkpoint key layout and the parameter-init RNG sequence.
        # `cw` is accepted for interface compatibility but unused here.
        super().__init__()
        # Per-token negation scorer (one logit per position).
        self.neg_detector = nn.Sequential(
            nn.Linear(H, H // 2),
            nn.GELU(),
            nn.Linear(H // 2, 1),
        )
        # Row 0 = "not negated" embedding, row 1 = "negated" embedding.
        self.neg_embed = nn.Embedding(2, H)
        self.context_attn = nn.MultiheadAttention(H, 4, dropout=dp, batch_first=True)
        # Gate over [state, context, polarity] -> scalar in (0, 1).
        self.gate = nn.Sequential(
            nn.Linear(3 * H, H),
            nn.GELU(),
            nn.Linear(H, 1),
            nn.Sigmoid(),
        )
        self.neg_transform = nn.Sequential(
            nn.Linear(H, H),
            nn.GELU(),
            nn.Dropout(dp),
            nn.Linear(H, H),
        )
        self.neg_aux_head = nn.Linear(H, 1)

    def forward(self, h, mask):
        """Return (modulated states, per-example auxiliary negation logit).

        `h` is (B, L, H); `mask` is the attention mask where 0 marks padding.
        """
        neg_prob = torch.sigmoid(self.neg_detector(h))
        # Soft mixture of the two polarity embeddings, weighted by neg_prob.
        polarity = (1 - neg_prob) * self.neg_embed.weight[0] + neg_prob * self.neg_embed.weight[1]
        h = h + polarity
        context, _ = self.context_attn(h, h, h, key_padding_mask=(mask == 0))
        gate_val = self.gate(torch.cat([h, context, polarity], -1))
        # Residual interpolation toward the negation transform, scaled by
        # negation probability * learned gate.
        h = h + neg_prob * gate_val * (self.neg_transform(h) - h)
        aux_logit = self.neg_aux_head(h[:, 0, :]).squeeze(-1)
        return h, aux_logit
|
|
| class CrossLayerGatedAttentionFusion(nn.Module): |
| def __init__(self, H, nl=4, nh=4, dp=0.1): |
| super().__init__() |
| self.lw = nn.Parameter(torch.ones(nl)/nl) |
| self.cross_attn = nn.MultiheadAttention(H, nh, dropout=dp, batch_first=True) |
| self.gate_proj = nn.Linear(2*H, H) |
| self.ln = nn.LayerNorm(H) |
| def forward(self, layers, mask): |
| w = F.softmax(self.lw, 0) |
| kv = (torch.stack(layers,0)*w.view(-1,1,1,1)).sum(0) |
| a, _ = self.cross_attn(layers[-1], kv, kv, key_padding_mask=(mask==0)) |
| g = torch.sigmoid(self.gate_proj(torch.cat([a, layers[-1]], -1))) |
| return self.ln(g*a + (1-g)*layers[-1]) |
|
|
class MultiGranularityPooling(nn.Module):
    """Pool token states at three granularities: [CLS], token, and segment."""

    def __init__(self, H, k=4, dp=0.1):
        super().__init__()
        self.k = k  # number of equal-width segments ("pseudo-sentences")
        # Token-level attention scorer.
        self.attn_pool = nn.Sequential(nn.Linear(H, H), nn.Tanh(), nn.Linear(H, 1))
        # Segment-level attention scorer.
        self.sent_attn = nn.Sequential(nn.Linear(H, H // 2), nn.Tanh(), nn.Linear(H // 2, 1))
        # Fuses the three pooled views back down to H.
        self.fusion = nn.Sequential(nn.Linear(3 * H, H), nn.LayerNorm(H), nn.GELU(), nn.Dropout(dp))

    def forward(self, h, mask):
        """Return a (B, H) vector fusing CLS, attention, and segment pooling."""
        batch, seq_len, hid = h.shape
        cls_vec = h[:, 0, :]

        # Token-level attention pooling; padded positions get -inf scores so
        # they receive zero softmax weight.
        scores = self.attn_pool(h).squeeze(-1)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        token_vec = (h * F.softmax(scores, -1).unsqueeze(-1)).sum(1)

        # Segment-level masked mean pooling over k contiguous chunks; the
        # last chunk absorbs any remainder of the sequence.
        chunk = seq_len // self.k
        seg_means = []
        for i in range(self.k):
            start = i * chunk
            end = seq_len if i == self.k - 1 else (i + 1) * chunk
            seg_mask = mask[:, start:end].unsqueeze(-1).float()
            seg_sum = (h[:, start:end, :] * seg_mask).sum(1)
            # clamp keeps an all-padding chunk from dividing by zero.
            seg_means.append(seg_sum / seg_mask.sum(1).clamp(min=1e-9))
        segs = torch.stack(seg_means, 1)
        seg_weights = F.softmax(self.sent_attn(segs).squeeze(-1), -1)
        seg_vec = (segs * seg_weights.unsqueeze(-1)).sum(1)

        return self.fusion(torch.cat([cls_vec, token_vec, seg_vec], -1))
|
|
class NEDAClassifier(nn.Module):
    """Negation-aware distress classifier: encoder -> CLGAF -> NACM -> MGP -> head."""

    def __init__(self):
        super().__init__()
        # Attribute names below are load-bearing: they must match the released
        # checkpoint's state-dict keys, so they are kept exactly as trained.
        cfg = AutoConfig.from_pretrained(REPO_ID)
        self.encoder = AutoModel.from_pretrained(REPO_ID, config=cfg)
        hidden = cfg.hidden_size
        self.nacm = NegationAwareContextualModulation(hidden, 3, 0.1)
        self.clgaf = CrossLayerGatedAttentionFusion(hidden, 4, 4, 0.1)
        self.pool = MultiGranularityPooling(hidden, 4, 0.1)
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, NUM_LABELS),
        )
        # Projection head from supervised-contrastive training; kept so the
        # checkpoint loads cleanly even though inference never calls it.
        self.supcon_proj = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, 128))

    def forward(self, input_ids, attention_mask):
        """Return (B, NUM_LABELS) logits for a tokenized batch."""
        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Fuse the last four encoder layers, then apply negation modulation.
        fused = self.clgaf(list(enc.hidden_states[-4:]), attention_mask)
        modulated, _ = self.nacm(fused, attention_mask)
        return self.classifier(self.pool(modulated, attention_mask))
|
|
| |
# --- One-time model setup at import time (downloads artifacts on first run) ---
print("Loading tokenizer and model...")
tok = AutoTokenizer.from_pretrained(HF_REPO_ID)
model_inf = NEDAClassifier().to(device)
# Fetch the fine-tuned state dict from the Hub; weights_only=True avoids
# executing arbitrary pickled code from the checkpoint.
w_path = hf_hub_download(repo_id=HF_REPO_ID, filename="pytorch_model.bin")
model_inf.load_state_dict(torch.load(w_path, map_location=device, weights_only=True))
model_inf.eval()  # inference only: disables dropout
print("Ready.")
|
|
def predict(text: str):
    """Classify `text` and return {label: probability} for the Gradio Label.

    Returns an empty dict for blank input so the UI shows nothing.
    """
    if not text.strip():
        return {}
    batch = tok(
        text,
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    with torch.no_grad():
        logits = model_inf(input_ids, attention_mask)
    probabilities = F.softmax(logits, -1).squeeze(0).cpu().tolist()
    return {ID2LABEL[idx]: round(p, 4) for idx, p in enumerate(probabilities)}
|
|
# Demo inputs shown beneath the Gradio textbox (one inner list per example row).
examples = [
    ["I feel so hopeless, I can't stop thinking about ending it all"],
    ["I have been feeling really anxious and stressed lately"],
    ["Today was a good day, I went for a walk and felt great"],
    ["I don't feel like doing anything anymore, nothing matters"],
]
|
|
# Gradio UI definition. The original title/description contained mojibake
# ("β" and "Β·" — a UTF-8 → cp1253 round-trip of "–" and "·"); the intended
# characters are restored below.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, placeholder="Enter text to classify...", label="Input Text"),
    outputs=gr.Label(num_top_classes=NUM_LABELS, label="Mental Distress Classification"),
    title="NEDA – Negation-aware Mental Health Text Classifier",
    description=(
        "**NEDA** classifies mental health Reddit posts into distress categories.\n"
        "Backbone: `mental/mental-roberta-base` · Components: NACM · CLGAF · MGP · SupCon · FGM · EMA"
    ),
    examples=examples,
    theme=gr.themes.Soft(),
)
|
|
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script (not on import).
    demo.launch()
|
|