Spaces:
Runtime error
Runtime error
| ''' | |
| Author: Qiguang Chen | |
| Date: 2023-01-11 10:39:26 | |
| LastEditors: Qiguang Chen | |
| LastEditTime: 2023-01-31 18:22:36 | |
| Description: | |
| ''' | |
| from torch import nn | |
| from common.utils import HiddenData, OutputData, InputData | |
class BaseDecoder(nn.Module):
    """Base class for all decoder modules.

    A decoder optionally runs an ``interaction`` module over the encoded hidden
    data, then applies an intent classifier and/or a slot classifier to it.

    Notice: it is often only necessary to change this module and its sub-modules.
    """

    def __init__(self, intent_classifier=None, slot_classifier=None, interaction=None):
        """Store the (optional) sub-modules.

        Args:
            intent_classifier (nn.Module, optional): classifier producing intent outputs. Defaults to None.
            slot_classifier (nn.Module, optional): classifier producing slot outputs. Defaults to None.
            interaction (nn.Module, optional): module applied to the hidden data before
                classification. Defaults to None.
        """
        super().__init__()
        self.intent_classifier = intent_classifier
        self.slot_classifier = slot_classifier
        self.interaction = interaction

    def forward(self, hidden: HiddenData):
        """forward

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: prediction logits; the intent or slot entry is None when
                the corresponding classifier is not configured.
        """
        if self.interaction is not None:
            hidden = self.interaction(hidden)
        intent = None
        slot = None
        if self.intent_classifier is not None:
            intent = self.intent_classifier(hidden)
        if self.slot_classifier is not None:
            slot = self.slot_classifier(hidden)
        return OutputData(intent, slot)

    def decode(self, output: OutputData, target: InputData = None):
        """decode output logits

        Args:
            output (OutputData): output logits data
            target (InputData, optional): input data with attention mask. Defaults to None.

        Returns:
            OutputData: decoded intent and slot results, as returned by each
                classifier's ``decode``; an entry is None when the corresponding
                classifier is not configured.
        """
        intent, slot = None, None
        if self.intent_classifier is not None:
            intent = self.intent_classifier.decode(output, target)
        if self.slot_classifier is not None:
            slot = self.slot_classifier.decode(output, target)
        return OutputData(intent, slot)

    @staticmethod
    def _loss_weight(classifier):
        """Return the classifier's configured 'weight' item, defaulting to 1.0 when absent or None."""
        weight = classifier.config.get("weight")
        return weight if weight is not None else 1.

    def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
        """compute loss.

        Notice: can set intent and slot loss weight by adding 'weight' config item in
        corresponding classifier configuration.

        Args:
            pred (OutputData): output logits data
            target (InputData): input golden data
            compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.

        Returns:
            Tuple: (weighted total loss, intent loss or None, slot loss or None).
                A component loss is None when its classifier is not configured or
                its ``compute_*_loss`` flag is False; such components do not
                contribute to the total.
        """
        loss = 0
        intent_loss = None
        slot_loss = None
        # Only accumulate losses that were actually computed; previously a disabled
        # loss (None) was still multiplied into the total, raising a TypeError.
        if self.intent_classifier is not None and compute_intent_loss:
            intent_loss = self.intent_classifier.compute_loss(pred, target)
            loss += intent_loss * self._loss_weight(self.intent_classifier)
        if self.slot_classifier is not None and compute_slot_loss:
            slot_loss = self.slot_classifier.compute_loss(pred, target)
            loss += slot_loss * self._loss_weight(self.slot_classifier)
        return loss, intent_loss, slot_loss
class StackPropagationDecoder(BaseDecoder):
    """Decoder that feeds the intent prediction into the slot branch.

    The intent classifier runs first on the encoded hidden data; the interaction
    module then receives that intent output together with the hidden data, and
    the slot classifier runs on the interaction's result.

    NOTE(review): intent_classifier, interaction and slot_classifier are all called
    unconditionally here, so none of them may be None for this decoder.
    """

    def forward(self, hidden: HiddenData):
        """Run intent prediction, intent-aware interaction, then slot prediction.

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: intent and slot classifier outputs.
        """
        intent_output = self.intent_classifier(hidden)
        fused_hidden = self.interaction(intent_output, hidden)
        slot_output = self.slot_classifier(fused_hidden)
        return OutputData(intent_output, slot_output)
class DCANetDecoder(BaseDecoder):
    """Decoder whose interaction module is given both classifiers.

    When an interaction module is configured, it is called with the hidden data
    and with the intent and slot classifier modules passed as the ``intent_emb``
    and ``slot_emb`` keyword arguments. Both classifiers are then applied to the
    (possibly updated) hidden data.
    """

    def forward(self, hidden: HiddenData):
        """Optionally apply the interaction, then classify intents and slots.

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: intent and slot classifier outputs.
        """
        fused_hidden = hidden if self.interaction is None else self.interaction(
            hidden, intent_emb=self.intent_classifier, slot_emb=self.slot_classifier
        )
        intent_output = self.intent_classifier(fused_hidden)
        slot_output = self.slot_classifier(fused_hidden)
        return OutputData(intent_output, slot_output)