# xlmr-multihead-tr / modeling_xlmr_multihead.py
# Author: celalkartoglu
# Commit: "Add multi-head XLM-R model (intent+NER) for TR" (c8b8422, verified)
import torch
import torch.nn as nn
from transformers import AutoModel
IGNORE_IDX = -100
class XLMRMultiHead(nn.Module):
    """XLM-R encoder with two task heads: sentence-level intent classification
    and token-level NER tagging.

    The intent head reads the first-token (<s>/[CLS]) representation; the NER
    head is applied to every token position. When BOTH label tensors are
    supplied, ``forward`` also returns a weighted sum of the two cross-entropy
    losses; otherwise ``loss`` is ``None``.
    """

    def __init__(self, base="xlm-roberta-base", n_intent=0, n_ner=0,
                 dropout=0.1, w_intent=1.0, w_ner=0.8):
        """
        Args:
            base: Hugging Face model id or local path of the encoder to load.
            n_intent: number of intent classes (width of the intent head).
            n_ner: number of NER tag classes (width of the token head).
            dropout: dropout probability applied before each head.
            w_intent: weight of the intent loss in the combined loss.
            w_ner: weight of the NER loss in the combined loss.
        """
        super().__init__()
        self.enc = AutoModel.from_pretrained(base)
        h = self.enc.config.hidden_size
        self.drop = nn.Dropout(dropout)
        self.intent = nn.Linear(h, n_intent)
        self.ner = nn.Linear(h, n_ner)
        self.ce_int = nn.CrossEntropyLoss()
        # IGNORE_IDX (-100) masks padding / continuation sub-word positions
        # out of the token-level loss.
        self.ce_tok = nn.CrossEntropyLoss(ignore_index=IGNORE_IDX)
        # Loss-mixing weights; previously hard-coded as 1.0 and 0.8 in forward().
        self.w_int = float(w_intent)
        self.w_ner = float(w_ner)

    def forward(self, input_ids, attention_mask, labels_intent=None, labels_ner=None):
        """Encode the batch and run both heads.

        Args:
            input_ids: LongTensor [B, T] of token ids.
            attention_mask: LongTensor [B, T]; 1 for real tokens, 0 for padding.
            labels_intent: optional LongTensor [B] of intent class ids.
            labels_ner: optional LongTensor [B, T] of tag ids
                (use IGNORE_IDX to exclude a position from the loss).

        Returns:
            dict with "loss" (None unless both label tensors are given),
            "logits_intent" of shape [B, n_intent], and
            "logits_ner" of shape [B, T, n_ner].
        """
        out = self.enc(input_ids=input_ids, attention_mask=attention_mask,
                       return_dict=True)
        hidden = out.last_hidden_state
        # NOTE: dropout is applied independently to the pooled slice and the
        # full sequence (matches the original behavior).
        cls = self.drop(hidden[:, 0])
        seq = self.drop(hidden)
        logits_intent = self.intent(cls)  # [B, n_intent]
        logits_ner = self.ner(seq)        # [B, T, n_ner]
        loss = None
        if labels_intent is not None and labels_ner is not None:
            l_i = self.ce_int(logits_intent, labels_intent)
            # Flatten tokens so CrossEntropyLoss sees [B*T, n_ner] vs [B*T].
            l_n = self.ce_tok(logits_ner.reshape(-1, logits_ner.size(-1)),
                              labels_ner.reshape(-1))
            loss = self.w_int * l_i + self.w_ner * l_n
        return {"loss": loss, "logits_intent": logits_intent, "logits_ner": logits_ner}