soltaniali committed on
Commit e19fe7e · verified · 1 Parent(s): f88800f

Upload modeling.py with huggingface_hub

Files changed (1)
  1. modeling.py +97 -0
modeling.py ADDED
@@ -0,0 +1,97 @@
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from torch import nn


class IntentClassifier(nn.Module):
    """Linear head over the [CLS] representation for intent detection."""

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_intent_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)


class SlotClassifier(nn.Module):
    """Token-level linear head for slot filling."""

    def __init__(self, input_dim, num_slot_labels, dropout_rate=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(input_dim, num_slot_labels)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)


class BertIDSF(BertPreTrainedModel):
    """Joint intent detection and slot filling on top of a BERT encoder."""

    def __init__(self, config, intent_label_lst, slot_label_lst, n_layers=1):  # n_layers unused; kept for interface compatibility
        super().__init__(config)
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.bert = BertModel(config=config)

        # Store 1-based id -> label maps in the config so they survive save/load
        self.config.dict2 = {str(idx + 1): label for idx, label in enumerate(slot_label_lst)}
        self.config.inte2 = {str(idx + 1): label for idx, label in enumerate(intent_label_lst)}

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels)

        # Initialize the newly added head weights (standard for PreTrainedModel subclasses)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,            # token-level slot label ids
        intents=None,           # sentence-level intent label ids
        output_attentions=True,
        lens=None,              # unused; kept for interface compatibility
        device=None,            # unused; kept for interface compatibility
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=True,  # attentions are always returned
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        # Intent logits from the [CLS] token, slot logits from every token
        intent_logits = self.intent_classifier(sequence_output[:, 0, :])
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0

        # Intent softmax loss
        if intents is not None:
            intent_loss_fct = nn.CrossEntropyLoss()
            intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intents.view(-1))
            total_loss += 0.5 * intent_loss

        # Slot softmax loss (label id 0 is the padding/ignore index)
        if labels is not None:
            slot_loss_fct = nn.CrossEntropyLoss(ignore_index=0)
            # Only keep the active (non-padding) positions in the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                slot_loss = slot_loss_fct(active_logits, active_labels)
            else:
                slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), labels.view(-1))
            total_loss += 0.5 * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # append hidden states / attentions if present
        outputs = (total_loss,) + outputs

        return outputs  # (loss, (intent_logits, slot_logits), hidden_states?, attentions?)
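For context, a minimal usage sketch, not part of the uploaded file: it builds a BertIDSF from a stock bert-base-uncased config and runs one forward pass with both loss terms active. The label lists and gold ids are hypothetical placeholders, and the encoder weights here are randomly initialized; a trained checkpoint would be loaded via from_pretrained instead.

import torch
from transformers import BertConfig, BertTokenizerFast

intent_labels = ["greet", "book_flight"]        # hypothetical intent vocabulary
slot_labels = ["PAD", "O", "B-city", "I-city"]  # hypothetical; id 0 is the ignored PAD index

config = BertConfig.from_pretrained("bert-base-uncased")
model = BertIDSF(config, intent_labels, slot_labels)  # encoder randomly initialized in this sketch

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
enc = tokenizer("book a flight to paris", return_tensors="pt")

intents = torch.tensor([1])                 # gold intent id per example
labels = torch.ones_like(enc["input_ids"])  # gold slot id per token ("O" everywhere here)

outputs = model(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    token_type_ids=enc["token_type_ids"],
    labels=labels,
    intents=intents,
)
loss, (intent_logits, slot_logits) = outputs[0], outputs[1]
print(loss, intent_logits.shape, slot_logits.shape)  # scalar loss, (1, 2), (1, seq_len, 4)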