Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from common.utils import HiddenData, OutputData, InputData | |
| from model.decoder import BaseDecoder | |
| from model.decoder.interaction.gl_gin_interaction import LSTMEncoder | |
| class IntentEncoder(nn.Module): | |
| def __init__(self,input_dim, dropout_rate): | |
| super().__init__() | |
| self.dropout_rate = dropout_rate | |
| self.__intent_lstm = LSTMEncoder( | |
| input_dim, | |
| input_dim, | |
| dropout_rate | |
| ) | |
| def forward(self, g_hiddens, seq_lens): | |
| intent_lstm_out = self.__intent_lstm(g_hiddens, seq_lens) | |
| return F.dropout(intent_lstm_out, p=self.dropout_rate, training=self.training) | |
| class GLGINDecoder(BaseDecoder): | |
| def __init__(self, intent_classifier, slot_classifier, interaction=None, **config): | |
| super().__init__(intent_classifier, slot_classifier, interaction) | |
| self.config=config | |
| self.intent_encoder = IntentEncoder(self.intent_classifier.config["input_dim"], self.config["dropout_rate"]) | |
| def forward(self, hidden: HiddenData, forced_slot=None, forced_intent=None, differentiable=None): | |
| seq_lens = hidden.inputs.attention_mask.sum(-1) | |
| intent_lstm_out = self.intent_encoder(hidden.slot_hidden, seq_lens) | |
| hidden.update_intent_hidden_state(intent_lstm_out) | |
| pred_intent = self.intent_classifier(hidden) | |
| intent_index = self.intent_classifier.decode(OutputData(pred_intent, None),hidden.inputs, | |
| return_list=False, | |
| return_sentence_level=True) | |
| slot_hidden = self.interaction( | |
| hidden, | |
| pred_intent=pred_intent, | |
| intent_index=intent_index, | |
| ) | |
| pred_slot = self.slot_classifier(slot_hidden) | |
| num_intent = self.intent_classifier.config["intent_label_num"] | |
| pred_slot = pred_slot.classifier_output[:, num_intent:] | |
| return OutputData(pred_intent, F.log_softmax(pred_slot, dim=1)) |