| import torch |
| import torch.nn as nn |
| from transformers import AutoModel |
| from torchcrf import CRF |
|
|
class MultiIntentClassifier(nn.Module):
    """Sigmoid head for multi-label (multi-intent) classification.

    Applies dropout, a linear projection, and an element-wise sigmoid so
    that every intent label receives an independent probability in (0, 1).
    """

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super(MultiIntentClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)
        self.sigmoid = nn.Sigmoid()
        self.reset_params()

    def forward(self, x):
        """Return per-intent probabilities with shape (..., num_intent_labels)."""
        return self.sigmoid(self.linear(self.dropout(x)))

    def reset_params(self):
        # NOTE(review): uniform init on both weight and bias (values in [0, 1),
        # not zero-centered) is unusual but kept exactly as the original.
        nn.init.uniform_(self.linear.weight)
        nn.init.uniform_(self.linear.bias)
|
|
class SlotClassifier(nn.Module):
    """Token-level linear head producing raw slot logits for each position."""

    def __init__(self, input_dim, num_slot_labels, dropout_rate=0.2):
        super(SlotClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_slot_labels)

    def forward(self, x):
        """Return unnormalized slot logits with shape (..., num_slot_labels)."""
        return self.linear(self.dropout(x))
|
|
class IntentTokenClassifier(nn.Module):
    """Token-level linear head producing raw intent logits per position."""

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super(IntentTokenClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        """Return unnormalized intent logits with shape (..., num_intent_labels)."""
        return self.linear(self.dropout(x))
|
|
class TagIntentClassifier(nn.Module):
    """Softmax head mapping a tag representation to an intent distribution.

    Output rows sum to 1 (softmax over dim 1).
    """

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super(TagIntentClassifier, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-row intent probabilities of shape (batch, num_intent_labels)."""
        logits = self.linear(self.dropout(x))
        return self.softmax(logits)
|
|
|
|
class BiaffineTagIntentClassifier(nn.Module):
    """
    Biaffine Tag-Intent Classifier
    score = h_cls^T U r + W [h_cls; r] + b
    """

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super(BiaffineTagIntentClassifier, self).__init__()
        self.input_dim = input_dim
        self.num_intent_labels = num_intent_labels
        self.dropout = nn.Dropout(dropout_rate)
        # Bilinear tensor: one (input_dim x input_dim) matrix per intent label.
        self.U = nn.Parameter(torch.Tensor(num_intent_labels, input_dim, input_dim))
        # Affine term over the concatenation [h_cls; r]; its bias is the b above.
        self.W = nn.Linear(2 * input_dim, num_intent_labels)
        self.softmax = nn.Softmax(dim=1)
        self.reset_params()

    def forward(self, h_cls, r):
        """
        Args:
            h_cls: [batch*num_mask, hidden_dim] - CLS representations
            r: [batch*num_mask, hidden_dim] - tag intent vectors

        Returns:
            [batch*num_mask, num_intent_labels] - probabilities
        """
        h_cls = self.dropout(h_cls)
        r = self.dropout(r)

        # h_cls^T U r for every label c: contract hidden dims on both sides.
        bilinear = torch.einsum('bh,chd,bd->bc', h_cls, self.U, r)

        # W [h_cls; r] + b
        affine = self.W(torch.cat((h_cls, r), dim=1))

        return self.softmax(bilinear + affine)

    def reset_params(self):
        # Xavier for both score terms; bias starts neutral.
        nn.init.xavier_uniform_(self.U)
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.zeros_(self.W.bias)
|
|
class VSLIM(nn.Module):
    """
    Features:
    - Multi-intent classification with sigmoid
    - Slot filling with optional CRF
    - Intent token classification
    - Tag-intent classification with B/BI masks
    - Intent attention for tag-intent
    """

    def __init__(self,
                 model_name,
                 num_slots,
                 num_intents,
                 num_token_intents,
                 num_tag_intents,
                 dropout=0.1,
                 use_crf=False,
                 num_mask=4,
                 cls_token_cat=True,
                 intent_attn=True,
                 use_biaffine_tag_intent=True,
                 args=None):
        """Build the encoder and all task heads.

        Args:
            model_name: HuggingFace model id/path for the transformer encoder.
            num_slots: number of slot labels.
            num_intents: number of utterance-level intent labels.
            num_token_intents: number of token-level intent labels.
            num_tag_intents: number of tag-intent labels.
            dropout: dropout rate shared by all heads.
            use_crf: if True, score slot/intent-token sequences with a CRF.
            num_mask: number of tag masks per example (rows of B/BI masks).
            cls_token_cat: concatenate CLS to tag vectors (non-biaffine path).
            intent_attn: reweight tag-intent probs by utterance intent probs.
            use_biaffine_tag_intent: use the biaffine tag-intent head.
            args: optional namespace with loss coefficients / ignore_index.
        """
        super().__init__()

        # Pretrained transformer backbone; its hidden size drives every head.
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size

        self.multi_intent_classifier = MultiIntentClassifier(hidden_size, num_intents, dropout)
        self.slot_classifier = SlotClassifier(hidden_size, num_slots, dropout)
        self.intent_token_classifier = IntentTokenClassifier(hidden_size, num_token_intents, dropout)

        self.use_biaffine_tag_intent = use_biaffine_tag_intent

        if use_biaffine_tag_intent:
            self.biaffine_tag_intent_classifier = BiaffineTagIntentClassifier(
                hidden_size, num_tag_intents, dropout
            )
        else:
            # Input doubles when the CLS vector is concatenated to each tag vector.
            tag_input_dim = 2 * hidden_size if cls_token_cat else hidden_size
            self.tag_intent_classifier = TagIntentClassifier(tag_input_dim, num_tag_intents, dropout)

        if use_crf:
            # NOTE(review): this single CRF has num_tags=num_slots but forward()
            # also applies it to intent-token logits; that only type-checks when
            # num_token_intents == num_slots — confirm with callers.
            self.crf = CRF(num_tags=num_slots, batch_first=True)

        self.use_crf = use_crf
        self.num_mask = num_mask
        self.cls_token_cat = cls_token_cat
        self.intent_attn = intent_attn
        self.num_intents = num_intents
        self.args = args

    def forward(self, input_ids, attention_mask, token_type_ids=None,
                intent_label_ids=None, slot_labels_ids=None,
                intent_token_ids=None, B_tag_mask=None, BI_tag_mask=None,
                tag_intent_label=None):
        """Run all heads; each loss whose labels are provided is added to total_loss.

        Args:
            input_ids, attention_mask: encoder inputs. token_type_ids is
                accepted for API compatibility but not passed to the encoder.
            intent_label_ids: multi-hot utterance intent targets (BCE).
            slot_labels_ids: per-token slot targets.
            intent_token_ids: per-token intent targets.
            B_tag_mask, BI_tag_mask: tag masks — presumably
                [batch, num_mask, seq_len]; verify against the data loader.
            tag_intent_label: tag-intent targets for the masked spans.

        Returns:
            dict with total/partial losses and per-head logits.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # [batch, seq_len, hidden]
        pooled_output = outputs.pooler_output        # [batch, hidden]

        total_loss = 0

        # Loss weights and the CE ignore index come from args when provided.
        W_UTTINTENT = self.args.intent_loss_coef if self.args else 1.0
        W_SLOT = self.args.slot_loss_coef if self.args else 2.0
        W_TOKINTENT = self.args.token_intent_loss_coef if self.args else 2.0
        W_TAGINTENT = self.args.tag_intent_coef if self.args else 1.0
        IGNORE_INDEX = self.args.ignore_index if self.args else -100

        # ---- 1) Utterance-level multi-intent (sigmoid probabilities + BCE) ----
        intent_logits = self.multi_intent_classifier(pooled_output)

        if intent_label_ids is not None:
            intent_loss_fct = nn.BCELoss()
            # epsilon guards log(0); in float32 it is a no-op for probs near 1
            intent_loss = intent_loss_fct(intent_logits + 1e-10, intent_label_ids)
            total_loss += W_UTTINTENT * intent_loss

        # ---- 2) Slot filling (CRF log-likelihood or masked cross-entropy) ----
        slot_logits = self.slot_classifier(sequence_output)

        if slot_labels_ids is not None:
            if self.use_crf:
                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction='mean')
                slot_loss = -1 * slot_loss  # CRF returns log-likelihood; negate for a loss
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
                if attention_mask is not None:
                    # Only score positions the attention mask marks as real tokens.
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(-1, slot_logits.size(-1))[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, slot_logits.size(-1)), slot_labels_ids.view(-1))

            total_loss += W_SLOT * slot_loss

        # ---- 3) Token-level intent classification ----
        intent_token_logits = self.intent_token_classifier(sequence_output)

        intent_token_loss = 0.0
        if intent_token_ids is not None:
            if self.use_crf:
                # NOTE(review): reuses the slot CRF (num_tags=num_slots) on
                # intent-token logits; see __init__ note.
                intent_token_loss = self.crf(intent_token_logits, intent_token_ids, mask=attention_mask.byte(), reduction='mean')
                intent_token_loss = -1 * intent_token_loss
            else:
                intent_token_loss_fct = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
                if attention_mask is not None:
                    active_intent_loss = attention_mask.view(-1) == 1
                    active_intent_logits = intent_token_logits.view(-1, intent_token_logits.size(-1))[active_intent_loss]
                    active_intent_tokens = intent_token_ids.view(-1)[active_intent_loss]
                    intent_token_loss = intent_token_loss_fct(active_intent_logits, active_intent_tokens)
                else:
                    intent_token_loss = intent_token_loss_fct(intent_token_logits.view(-1, intent_token_logits.size(-1)), intent_token_ids.view(-1))

            total_loss += W_TOKINTENT * intent_token_loss

        # ---- 4) Tag-intent classification over masked spans ----
        tag_intent_loss = 0.0
        tag_intent_logits = None

        if B_tag_mask is not None and BI_tag_mask is not None and tag_intent_label is not None:
            # Masks must be float to einsum with the float encoder output.
            # FIX: the old code compared Tensor.type() — which returns a string
            # like 'torch.FloatTensor' — against torch.float32, so the guard was
            # always True. Compare dtypes instead.
            if BI_tag_mask.dtype != torch.float32:
                BI_tag_mask = BI_tag_mask.float()
            if B_tag_mask.dtype != torch.float32:
                # NOTE(review): B_tag_mask is converted but never used below.
                B_tag_mask = B_tag_mask.float()

            # Sum token representations under each BI mask:
            # [b, m, l] x [b, l, d] -> [b, m, d]
            tag_intent_vec = torch.einsum('bml,bld->bmd', BI_tag_mask, sequence_output)

            if self.use_biaffine_tag_intent:
                # Pair each of the num_mask tag vectors with the same CLS vector.
                h_cls = pooled_output.unsqueeze(1)
                h_cls = h_cls.repeat(1, self.num_mask, 1)

                # Flatten (batch, num_mask) into one axis for the biaffine head.
                batch_size = h_cls.size(0)
                h_cls_flat = h_cls.view(batch_size * self.num_mask, -1)
                r_flat = tag_intent_vec.view(batch_size * self.num_mask, -1)

                tag_intent_logits = self.biaffine_tag_intent_classifier(h_cls_flat, r_flat)

            else:
                if self.cls_token_cat:
                    cls_token = pooled_output.unsqueeze(1)
                    cls_token = cls_token.repeat(1, self.num_mask, 1)
                    tag_intent_vec = torch.cat((cls_token, tag_intent_vec), dim=2)

                tag_intent_vec = tag_intent_vec.view(tag_intent_vec.size(0) * tag_intent_vec.size(1), -1)
                tag_intent_logits = self.tag_intent_classifier(tag_intent_vec)

            if self.intent_attn:
                # Reweight tag-intent probs by the utterance-level intent probs.
                intent_probs = intent_logits.unsqueeze(1)
                intent_probs = intent_probs.repeat(1, self.num_mask, 1)
                intent_probs = intent_probs.view(intent_probs.size(0) * intent_probs.size(1), -1)

                # Prepend a zero column — presumably tag-intent label 0 is a
                # pad/none class with no utterance-intent counterpart; verify.
                pad_probs = torch.zeros(intent_probs.size(0), 1, device=intent_probs.device)
                intent_probs_expanded = torch.cat([pad_probs, intent_probs], dim=1)

                # Elementwise gate, then renormalize to a distribution.
                tag_intent_logits = tag_intent_logits * intent_probs_expanded
                tag_intent_logits = tag_intent_logits.div(tag_intent_logits.sum(dim=1, keepdim=True) + 1e-10)

            # NLL over log-probabilities (heads already output probabilities).
            nll_fct = nn.NLLLoss(ignore_index=IGNORE_INDEX)
            tag_intent_loss = nll_fct(torch.log(tag_intent_logits + 1e-10), tag_intent_label.view(-1))
            total_loss += W_TAGINTENT * tag_intent_loss

        return {
            "total_loss": total_loss,
            "intent_loss": intent_loss if intent_label_ids is not None else 0,
            "slot_loss": slot_loss if slot_labels_ids is not None else 0,
            "intent_token_loss": intent_token_loss,
            "tag_intent_loss": tag_intent_loss,
            "intent_logits": intent_logits,
            "slot_logits": slot_logits,
            "intent_token_logits": intent_token_logits,
            "tag_intent_logits": tag_intent_logits if B_tag_mask is not None else None
        }
|
|
|
|