import torch
import torch.nn as nn
from transformers import AutoModel

IGNORE_IDX = -100  # label id skipped by the token-level loss (padding/subword positions)


class XLMRMultiHead(nn.Module):
    """XLM-RoBERTa encoder with a sequence-level intent head and a token-level NER head."""

    def __init__(self, base="xlm-roberta-base", n_intent=0, n_ner=0, dropout=0.1):
        super().__init__()
        self.enc = AutoModel.from_pretrained(base)
        h = self.enc.config.hidden_size
        self.drop = nn.Dropout(dropout)
        self.intent = nn.Linear(h, n_intent)  # intent classifier over the <s> (CLS) state
        self.ner = nn.Linear(h, n_ner)        # per-token NER classifier
        self.ce_int = nn.CrossEntropyLoss()
        self.ce_tok = nn.CrossEntropyLoss(ignore_index=IGNORE_IDX)

    def forward(self, input_ids, attention_mask, labels_intent=None, labels_ner=None):
        out = self.enc(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        cls = self.drop(out.last_hidden_state[:, 0])  # first-token state, [B, H]
        seq = self.drop(out.last_hidden_state)        # full sequence states, [B, T, H]
        li = self.intent(cls)  # [B, n_intent]
        ln = self.ner(seq)     # [B, T, n_ner]
        loss = None
        if labels_intent is not None and labels_ner is not None:
            l_i = self.ce_int(li, labels_intent)
            # Flatten tokens so CrossEntropyLoss sees [B*T, n_ner] vs. [B*T];
            # positions labeled IGNORE_IDX do not contribute to the loss.
            l_n = self.ce_tok(ln.reshape(-1, ln.size(-1)), labels_ner.reshape(-1))
            loss = 1.0 * l_i + 0.8 * l_n  # fixed task weighting: intent 1.0, NER 0.8
        return {"loss": loss, "logits_intent": li, "logits_ner": ln}
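

# Usage sketch (not part of the original source): assumes a matching HuggingFace
# tokenizer for the same checkpoint; the label counts and label ids below are
# hypothetical values chosen purely for illustration.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    model = XLMRMultiHead(n_intent=5, n_ner=9)  # hypothetical label-set sizes

    batch = tokenizer(["book a flight to Paris"], return_tensors="pt", padding=True)
    labels_intent = torch.tensor([2])  # one intent id per sequence (dummy value)
    # Token labels: start fully ignored, then label one content token so the
    # NER loss has at least one valid target (an all-ignored batch yields NaN).
    labels_ner = torch.full_like(batch["input_ids"], IGNORE_IDX)
    labels_ner[0, 1] = 3  # dummy NER tag id for the first content token

    out = model(batch["input_ids"], batch["attention_mask"], labels_intent, labels_ner)
    print(out["loss"], out["logits_intent"].shape, out["logits_ner"].shape)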