from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from torch import nn


class IntentClassifier(nn.Module):
    """Linear head that maps a pooled sentence representation to intent logits."""

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)


class SlotClassifier(nn.Module):
    """Linear head that maps each token representation to slot logits."""

    def __init__(self, input_dim, num_slot_labels, dropout_rate=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_slot_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)


class BertIDSF(BertPreTrainedModel):
    """Joint intent detection and slot filling on top of a BERT encoder."""

    def __init__(self, config, intent_label_lst, slot_label_lst, n_layers=1):
        # n_layers is accepted for interface compatibility but is not used here.
        super().__init__(config)
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.bert = BertModel(config=config)

        # Store the label vocabularies on the config (1-based string keys) so
        # they are saved alongside the model checkpoint.
        self.config.dict2 = {str(idx + 1): label for idx, label in enumerate(slot_label_lst)}
        self.config.inte2 = {str(idx + 1): label for idx, label in enumerate(intent_label_lst)}

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels)

        # Initialize the newly added classifier weights; the encoder weights are
        # restored from the pretrained checkpoint when loaded via from_pretrained.
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        intents=None,
        output_attentions=True,
        lens=None,
        device=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=True,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        # Intent is predicted from the [CLS] token; slots are predicted per token.
        intent_logits = self.intent_classifier(sequence_output[:, 0, :])
        slot_logits = self.slot_classifier(sequence_output)
        # Total loss is an equally weighted (0.5 / 0.5) sum of the two task losses.
        total_loss = 0

        # 1. Intent classification loss.
        if intents is not None:
            intent_loss_fct = nn.CrossEntropyLoss()
            intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intents.view(-1))
            total_loss += 0.5 * intent_loss

        # 2. Slot filling loss, computed only over active (non-padded) tokens;
        #    slot label index 0 is treated as padding and ignored.
        if labels is not None:
            slot_loss_fct = nn.CrossEntropyLoss(ignore_index=0)

            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                slot_loss = slot_loss_fct(active_logits, active_labels)
            else:
                slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), labels.view(-1))
            total_loss += 0.5 * slot_loss

        # Return (total_loss, (intent_logits, slot_logits), attentions, ...).
        outputs = ((intent_logits, slot_logits),) + outputs[2:]
        outputs = (total_loss,) + outputs

        return outputs
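

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows one way
# BertIDSF could be instantiated and called. The checkpoint name, label lists,
# example sentence, and dummy targets below are assumptions for demonstration
# only; swap in the label sets and data pipeline used by the surrounding repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import BertTokenizerFast

    intent_labels = ["greet", "book_flight"]        # hypothetical intent label set
    slot_labels = ["PAD", "O", "B-city", "I-city"]  # index 0 acts as the ignored PAD label

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    model = BertIDSF.from_pretrained(
        "bert-base-uncased",
        intent_label_lst=intent_labels,
        slot_label_lst=slot_labels,
    )

    enc = tokenizer("book a flight to paris", return_tensors="pt")
    intents = torch.tensor([1])                     # one intent id per sequence
    labels = torch.ones_like(enc["input_ids"])      # dummy per-token slot label ids ("O")

    loss, (intent_logits, slot_logits), *_ = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        token_type_ids=enc.get("token_type_ids"),
        labels=labels,
        intents=intents,
    )
    print(loss.item(), intent_logits.shape, slot_logits.shape)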