# vimdhayak — NEDA model + Gradio app.py + requirements.txt (commit af3b6da, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
from huggingface_hub import hf_hub_download
import gradio as gr
REPO_ID = "mental/mental-roberta-base" # backbone
HF_REPO_ID = "vimdhayak/neda-learned-negation" # ← same as upload repo
ID2LABEL = {0: 'Anxious', 1: 'Depressed', 2: 'Frustrated', 3: 'Others', 4: 'Suicidal'}
LABEL2ID = {'Anxious': 0, 'Depressed': 1, 'Frustrated': 2, 'Others': 3, 'Suicidal': 4}
NUM_LABELS = 5
MAX_LENGTH = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ── Architecture (must match training) ───────────────────────
class NegationAwareContextualModulation(nn.Module):
def __init__(self, H, cw=3, dp=0.1):
super().__init__()
self.neg_detector = nn.Sequential(nn.Linear(H, H//2), nn.GELU(), nn.Linear(H//2, 1))
self.neg_embed = nn.Embedding(2, H)
self.context_attn = nn.MultiheadAttention(H, 4, dropout=dp, batch_first=True)
self.gate = nn.Sequential(nn.Linear(3*H, H), nn.GELU(), nn.Linear(H, 1), nn.Sigmoid())
self.neg_transform = nn.Sequential(nn.Linear(H, H), nn.GELU(), nn.Dropout(dp), nn.Linear(H, H))
self.neg_aux_head = nn.Linear(H, 1)
def forward(self, h, mask):
p = torch.sigmoid(self.neg_detector(h))
emb = (1-p)*self.neg_embed.weight[0] + p*self.neg_embed.weight[1]
h = h + emb
ctx, _ = self.context_attn(h, h, h, key_padding_mask=(mask==0))
g = self.gate(torch.cat([h, ctx, emb], -1))
h = h + p * g * (self.neg_transform(h) - h)
return h, self.neg_aux_head(h[:,0,:]).squeeze(-1)
class CrossLayerGatedAttentionFusion(nn.Module):
def __init__(self, H, nl=4, nh=4, dp=0.1):
super().__init__()
self.lw = nn.Parameter(torch.ones(nl)/nl)
self.cross_attn = nn.MultiheadAttention(H, nh, dropout=dp, batch_first=True)
self.gate_proj = nn.Linear(2*H, H)
self.ln = nn.LayerNorm(H)
def forward(self, layers, mask):
w = F.softmax(self.lw, 0)
kv = (torch.stack(layers,0)*w.view(-1,1,1,1)).sum(0)
a, _ = self.cross_attn(layers[-1], kv, kv, key_padding_mask=(mask==0))
g = torch.sigmoid(self.gate_proj(torch.cat([a, layers[-1]], -1)))
return self.ln(g*a + (1-g)*layers[-1])
class MultiGranularityPooling(nn.Module):
def __init__(self, H, k=4, dp=0.1):
super().__init__()
self.k = k
self.attn_pool = nn.Sequential(nn.Linear(H,H), nn.Tanh(), nn.Linear(H,1))
self.sent_attn = nn.Sequential(nn.Linear(H,H//2), nn.Tanh(), nn.Linear(H//2,1))
self.fusion = nn.Sequential(nn.Linear(3*H,H), nn.LayerNorm(H), nn.GELU(), nn.Dropout(dp))
def forward(self, h, mask):
B, L, H = h.shape
cls = h[:,0,:]
sc = self.attn_pool(h).squeeze(-1).masked_fill(mask==0, float("-inf"))
ar = (h * F.softmax(sc,-1).unsqueeze(-1)).sum(1)
cs = L // self.k
sents = []
for i in range(self.k):
s, e = i*cs, (i+1)*cs if i<self.k-1 else L
m = mask[:,s:e].unsqueeze(-1).float()
sents.append((h[:,s:e,:]*m).sum(1) / m.sum(1).clamp(1e-9))
sents = torch.stack(sents, 1)
sr = (sents * F.softmax(self.sent_attn(sents).squeeze(-1),-1).unsqueeze(-1)).sum(1)
return self.fusion(torch.cat([cls, ar, sr], -1))
class NEDAClassifier(nn.Module):
def __init__(self):
super().__init__()
enc_cfg = AutoConfig.from_pretrained(REPO_ID)
self.encoder = AutoModel.from_pretrained(REPO_ID, config=enc_cfg)
H = enc_cfg.hidden_size
self.nacm = NegationAwareContextualModulation(H, 3, 0.1)
self.clgaf = CrossLayerGatedAttentionFusion(H, 4, 4, 0.1)
self.pool = MultiGranularityPooling(H, 4, 0.1)
self.classifier = nn.Sequential(
nn.Linear(H, 512), nn.LayerNorm(512),
nn.GELU(), nn.Dropout(0.1), nn.Linear(512, NUM_LABELS),
)
self.supcon_proj = nn.Sequential(nn.Linear(H,H), nn.GELU(), nn.Linear(H,128))
def forward(self, input_ids, attention_mask):
out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
h = self.clgaf(list(out.hidden_states[-4:]), attention_mask)
h, _ = self.nacm(h, attention_mask)
return self.classifier(self.pool(h, attention_mask))
# ── Load at startup ───────────────────────────────────────────
print("Loading tokenizer and model...")
tok = AutoTokenizer.from_pretrained(HF_REPO_ID)
model_inf = NEDAClassifier().to(device)
w_path = hf_hub_download(repo_id=HF_REPO_ID, filename="pytorch_model.bin")
model_inf.load_state_dict(torch.load(w_path, map_location=device, weights_only=True))
model_inf.eval()
print("Ready.")
def predict(text: str):
if not text.strip():
return {}
enc = tok(text, max_length=MAX_LENGTH, padding="max_length", truncation=True, return_tensors="pt")
with torch.no_grad():
logits = model_inf(enc["input_ids"].to(device), enc["attention_mask"].to(device))
probs = F.softmax(logits, -1).squeeze(0).cpu().tolist()
return {ID2LABEL[i]: round(probs[i], 4) for i in range(NUM_LABELS)}
examples = [
["I feel so hopeless, I can't stop thinking about ending it all"],
["I have been feeling really anxious and stressed lately"],
["Today was a good day, I went for a walk and felt great"],
["I don't feel like doing anything anymore, nothing matters"],
]
demo = gr.Interface(
fn=predict,
inputs=gr.Textbox(lines=4, placeholder="Enter text to classify...", label="Input Text"),
outputs=gr.Label(num_top_classes=NUM_LABELS, label="Mental Distress Classification"),
title="NEDA β€” Negation-aware Mental Health Text Classifier",
description=(
"**NEDA** classifies mental health Reddit posts into distress categories.\n"
"Backbone: `mental/mental-roberta-base` Β· Components: NACM Β· CLGAF Β· MGP Β· SupCon Β· FGM Β· EMA"
),
examples=examples,
theme=gr.themes.Soft(),
)
if __name__ == "__main__":
demo.launch()