import torch
import torch.nn as nn
from transformers import AutoModel

# Label value that CrossEntropyLoss skips for token-level (NER) targets,
# e.g. padding and sub-word continuation positions.
IGNORE_IDX = -100


class XLMRMultiHead(nn.Module):
    """XLM-RoBERTa encoder with two heads: sentence-level intent classification
    and token-level NER tagging, trained jointly with a weighted sum of losses."""

    def __init__(self, base="xlm-roberta-base", n_intent=0, n_ner=0, dropout=0.1):
        super().__init__()
        self.enc = AutoModel.from_pretrained(base)
        h = self.enc.config.hidden_size
        self.drop = nn.Dropout(dropout)
        # Callers must pass the real label counts; the defaults of 0 are placeholders.
        self.intent = nn.Linear(h, n_intent)  # sentence-level head
        self.ner = nn.Linear(h, n_ner)        # token-level head
        self.ce_int = nn.CrossEntropyLoss()
        self.ce_tok = nn.CrossEntropyLoss(ignore_index=IGNORE_IDX)

    def forward(self, input_ids, attention_mask, labels_intent=None, labels_ner=None):
        out = self.enc(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        cls = self.drop(out.last_hidden_state[:, 0])  # <s> embedding as sentence representation
        seq = self.drop(out.last_hidden_state)        # per-token embeddings for NER

        li = self.intent(cls)  # (batch, n_intent)
        ln = self.ner(seq)     # (batch, seq_len, n_ner)

        loss = None
        if labels_intent is not None and labels_ner is not None:
            l_i = self.ce_int(li, labels_intent)
            # Flatten token logits and labels; positions labelled IGNORE_IDX are skipped.
            l_n = self.ce_tok(ln.reshape(-1, ln.size(-1)), labels_ner.reshape(-1))
            # Fixed task weights: intent 1.0, NER 0.8.
            loss = 1.0 * l_i + 0.8 * l_n
        return {"loss": loss, "logits_intent": li, "logits_ner": ln}
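

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only). The label counts, batch size,
# sequence length, and dummy label values below are made up, not tied to any
# real dataset; running this downloads the base checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = XLMRMultiHead(n_intent=5, n_ner=9)
    vocab = model.enc.config.vocab_size
    ids = torch.randint(0, vocab, (2, 16))                       # dummy batch: 2 sequences of 16 tokens
    mask = torch.ones_like(ids)
    y_int = torch.tensor([1, 3])                                  # one intent label per sequence
    y_ner = torch.full((2, 16), IGNORE_IDX, dtype=torch.long)     # mask all positions...
    y_ner[:, 1:5] = 2                                             # ...except a few dummy entity tokens
    out = model(ids, mask, labels_intent=y_int, labels_ner=y_ner)
    print(out["loss"].item(), out["logits_intent"].shape, out["logits_ner"].shape)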