Spaces:
Runtime error
Runtime error
| ''' | |
| Author: Qiguang Chen | |
| Date: 2023-01-11 10:39:26 | |
| LastEditors: Qiguang Chen | |
| LastEditTime: 2023-01-31 20:07:00 | |
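| Description: Classifier (decoder) modules — linear, autoregressive LSTM and MLP — for intent and slot prediction. | |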
| ''' | |
| import random | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from torch.nn import CrossEntropyLoss | |
| from model.decoder import decoder_utils | |
| from torchcrf import CRF | |
| from common.utils import HiddenData, OutputData, InputData, ClassifierOutputData, unpack_sequence, pack_sequence, \ | |
| instantiate | |
class BaseClassifier(nn.Module):
    """Base class for all classifier modules.

    Stores the keyword configuration, builds the loss function and provides
    the shared ``decode`` / ``compute_loss`` helpers, which delegate to
    ``decoder_utils``. Subclasses must implement ``forward``.
    """

    # PyTorch's CrossEntropyLoss default; used when config has no ignore_index.
    _DEFAULT_IGNORE_INDEX = -100

    def __init__(self, **config):
        super().__init__()
        self.config = config
        if config.get("loss_fn"):
            self.loss_fn = instantiate(config.get("loss_fn"))
        else:
            # CrossEntropyLoss cannot compute with ignore_index=None, so fall
            # back to the library default instead of deferring a crash.
            self.loss_fn = CrossEntropyLoss(ignore_index=self._ignore_index())

    def _ignore_index(self):
        """Return the configured ``ignore_index``, or -100 when it is unset/None."""
        value = self.config.get("ignore_index")
        return value if value is not None else self._DEFAULT_IGNORE_INDEX

    def forward(self, *args, **kwargs):
        raise NotImplementedError("No implemented classifier.")

    def decode(self, output: OutputData,
               target: InputData = None,
               return_list=True,
               return_sentence_level=None):
        """Decode output logits into label ids.

        Args:
            output (OutputData): output logits data.
            target (InputData, optional): input data with attention mask. Defaults to None.
            return_list (bool, optional): if True return a list, else a torch Tensor. Defaults to True.
            return_sentence_level (bool, optional): if True decode sentence-level intent, else
                token-level. When None, falls back to ``config["return_sentence_level"]`` and
                then to False.

        Returns:
            List or Tensor: decoded sequence ids.
        """
        if return_sentence_level is None:
            return_sentence_level = self.config.get("return_sentence_level")
            if return_sentence_level is None:
                return_sentence_level = False
        return decoder_utils.decode(output, target,
                                    return_list=return_list,
                                    return_sentence_level=return_sentence_level,
                                    pred_type=self.config.get("mode"),
                                    use_multi=self.config.get("use_multi"),
                                    multi_threshold=self.config.get("multi_threshold"))

    def compute_loss(self, pred: OutputData, target: InputData):
        """Compute the training loss.

        Args:
            pred (OutputData): output logits data.
            target (InputData): input golden data.

        Returns:
            Tensor: loss result.

        Raises:
            AttributeError: if ``use_crf`` is set but the subclass did not create ``self.CRF``.
        """
        crf_layer = self.CRF if self.config.get("use_crf") else None
        return decoder_utils.compute_loss(pred, target, criterion_type=self.config["mode"],
                                          use_crf=crf_layer is not None,
                                          ignore_index=self._ignore_index(),
                                          use_multi=self.config.get("use_multi"),
                                          loss_fn=self.loss_fn,
                                          CRF=crf_layer)
class LinearClassifier(BaseClassifier):
    """
    Decoder that maps hidden states to label logits with a single linear layer.
    """

    def __init__(self, **config):
        """Build the linear heads requested by ``config``.

        Args:
            config (dict):
                input_dim (int): hidden state dim.
                use_slot (bool): whether to classify slot labels.
                slot_label_num (int, optional): number of slot labels. Required if use_slot is True.
                use_intent (bool): whether to classify intent labels.
                intent_label_num (int, optional): number of intent labels. Required if use_intent is True.
                use_crf (bool): whether to add a CRF on top of the slot head.
        """
        super().__init__(**config)
        in_dim = config["input_dim"]
        # Submodules are created in a fixed order (slot head, CRF, intent head)
        # so parameter initialisation and state_dict layout stay stable.
        if config.get("use_slot"):
            slot_num = config["slot_label_num"]
            self.slot_classifier = nn.Linear(in_dim, slot_num)
            if self.config.get("use_crf"):
                self.CRF = CRF(num_tags=slot_num, batch_first=True)
        if config.get("use_intent"):
            self.intent_classifier = nn.Linear(in_dim, config["intent_label_num"])

    def forward(self, hidden: HiddenData):
        """Project hidden states to logits.

        The intent head takes precedence when ``use_intent`` is set; otherwise the slot head
        is used when ``use_slot`` is set. Returns None if neither flag is set.
        """
        if self.config.get("use_intent"):
            logits = self.intent_classifier(hidden.get_intent_hidden_state())
        elif self.config.get("use_slot"):
            logits = self.slot_classifier(hidden.get_slot_hidden_state())
        else:
            return None
        return ClassifierOutputData(logits)
class AutoregressiveLSTMClassifier(BaseClassifier):
    """
    Decoder structure based on unidirectional LSTM.

    When ``embedding_dim`` is set, the embedding of the previous label is
    concatenated to each token's hidden state. During training that label comes
    from the gold sequence with probability ``force_ratio`` (teacher forcing);
    otherwise the model feeds back its own greedy (top-1) prediction step by step.
    """

    def __init__(self, **config):
        """ Construction function for Decoder.

        Args:
            config (dict):
                input_dim (int): input dimension of Decoder. In fact, it's encoder hidden size.
                use_slot (bool): whether to classify slot label.
                slot_label_num (int, optional): the number of slot label. Enabled if use_slot is True.
                use_intent (bool): whether to classify intent label.
                intent_label_num (int, optional): the number of intent label. Enabled if use_intent is True.
                use_crf (bool): whether to use crf for slot.
                hidden_dim (int): hidden dimension of iterative LSTM.
                embedding_dim (int): if it's not None, the input and output are relevant.
                dropout_rate (float): dropout rate used by the dropout layer and between LSTM layers.
                force_ratio (float, optional): teacher-forcing probability during training.
                    Defaults to 0.0 (never force) when missing.
                bidirectional (bool): passed to ``nn.LSTM``.
                layer_num (int): number of LSTM layers.
                ignore_index (int, optional): label id to skip when teacher forcing. Defaults to -100.
        """
        super(AutoregressiveLSTMClassifier, self).__init__(**config)
        if config.get("use_slot") and config.get("use_crf"):
            self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
        self.input_dim = config["input_dim"]
        self.hidden_dim = config["hidden_dim"]
        if config.get("use_intent"):
            self.output_dim = config["intent_label_num"]
        if config.get("use_slot"):
            self.output_dim = config["slot_label_num"]
        self.dropout_rate = config["dropout_rate"]
        self.embedding_dim = config.get("embedding_dim")
        # Without a default, a missing force_ratio crashed training on
        # ``random.random() < None``; 0.0 means "never teacher-force".
        force_ratio = config.get("force_ratio")
        self.force_ratio = force_ratio if force_ratio is not None else 0.0
        self.config = config
        self.ignore_index = config.get("ignore_index") if config.get("ignore_index") is not None else -100
        # If embedding_dim is not None, the output and input
        # of this structure is relevant.
        if self.embedding_dim is not None:
            self.embedding_layer = nn.Embedding(self.output_dim, self.embedding_dim)
            # Stand-in "previous label" embedding for the first token of each sentence.
            self.init_tensor = nn.Parameter(
                torch.randn(1, self.embedding_dim),
                requires_grad=True
            )
        # Make sure the input dimension of iterative LSTM.
        if self.embedding_dim is not None:
            lstm_input_dim = self.input_dim + self.embedding_dim
        else:
            lstm_input_dim = self.input_dim
        # Network parameter definition.
        self.dropout_layer = nn.Dropout(self.dropout_rate)
        self.lstm_layer = nn.LSTM(
            input_size=lstm_input_dim,
            hidden_size=self.hidden_dim,
            batch_first=True,
            bidirectional=self.config["bidirectional"],
            dropout=self.dropout_rate,
            num_layers=self.config["layer_num"]
        )
        # NOTE(review): with bidirectional=True nn.LSTM emits 2 * hidden_dim features,
        # which this layer does not account for unless internal_interaction projects
        # back to hidden_dim — confirm the intended configuration.
        self.linear_layer = nn.Linear(
            self.hidden_dim,
            self.output_dim
        )

    @staticmethod
    def _forward_fill_ignored(labels, ignore_index):
        """Replace each ``ignore_index`` label (except position 0) with the closest earlier kept label.

        Vectorized equivalent of a left-to-right loop that repeats the previous output
        whenever the current label equals ``ignore_index``. It avoids one tensor-to-bool
        conversion (a host sync on GPU) per token. Returns a 1-D tensor.
        """
        flat = labels.reshape(-1)
        keep = flat != ignore_index
        keep[0] = True  # the first packed label is always kept as-is
        positions = torch.arange(flat.shape[0], device=flat.device)
        source = torch.where(keep, positions, torch.zeros_like(positions))
        # Running max yields, for each i, the largest kept index j <= i.
        source = torch.cummax(source, dim=0).values
        return flat[source]

    def forward(self, hidden: HiddenData, internal_interaction=None, **interaction_args):
        """ Forward process for decoder.

        :param hidden: encoder output; ``slot_hidden`` is packed per sentence using the
            lengths from ``inputs.attention_mask``.
        :param internal_interaction: optional callable applied to the LSTM output; it
            receives ``sent_id`` in ``interaction_args``.
        :return: ClassifierOutputData with log-softmax scores, or raw scores when
            ``use_multi`` is set.
        """
        input_tensor = hidden.slot_hidden
        seq_lens = hidden.inputs.attention_mask.sum(-1).detach().cpu().tolist()
        output_tensor_list, sent_start_pos = [], 0
        input_tensor = pack_sequence(input_tensor, seq_lens)
        forced_input = None
        if self.training:
            if random.random() < self.force_ratio:
                if self.config["mode"] == "slot":
                    forced_slot = pack_sequence(hidden.inputs.slot, seq_lens)
                    # NOTE(review): the fill runs over the whole packed batch, so an ignored
                    # first token of a later sentence inherits the previous sentence's last label.
                    forced_input = self._forward_fill_ignored(forced_slot, self.ignore_index)
                if self.config["mode"] == "token-level-intent":
                    forced_intent = hidden.inputs.intent.unsqueeze(1).repeat(1, hidden.inputs.slot.shape[1])
                    forced_input = pack_sequence(forced_intent, seq_lens)
        if self.embedding_dim is None or forced_input is not None:
            # Whole-sentence mode: no label feedback, or teacher-forced labels.
            for sent_i in range(0, len(seq_lens)):
                sent_end_pos = sent_start_pos + seq_lens[sent_i]
                # Segment input hidden tensors.
                seg_hiddens = input_tensor[sent_start_pos: sent_end_pos, :]
                if self.embedding_dim is not None and forced_input is not None:
                    if seq_lens[sent_i] > 1:
                        # Shift gold labels right by one: token t sees label t-1.
                        seg_forced_input = forced_input[sent_start_pos: sent_end_pos]
                        seg_forced_tensor = self.embedding_layer(seg_forced_input)[:-1]
                        seg_prev_tensor = torch.cat([self.init_tensor, seg_forced_tensor], dim=0)
                    else:
                        seg_prev_tensor = self.init_tensor
                    # Concatenate forced target tensor.
                    combined_input = torch.cat([seg_hiddens, seg_prev_tensor], dim=1)
                else:
                    combined_input = seg_hiddens
                dropout_input = self.dropout_layer(combined_input)
                lstm_out, _ = self.lstm_layer(dropout_input.view(1, seq_lens[sent_i], -1))
                if internal_interaction is not None:
                    interaction_args["sent_id"] = sent_i
                    lstm_out = internal_interaction(torch.transpose(lstm_out, 0, 1), **interaction_args)[:, 0]
                linear_out = self.linear_layer(lstm_out.view(seq_lens[sent_i], -1))
                output_tensor_list.append(linear_out)
                sent_start_pos = sent_end_pos
        else:
            # Step-by-step greedy mode: feed back the model's own top-1 prediction.
            for sent_i in range(0, len(seq_lens)):
                prev_tensor = self.init_tensor
                # It's necessary to remember h and c state
                # when output prediction every single step.
                last_h, last_c = None, None
                sent_end_pos = sent_start_pos + seq_lens[sent_i]
                for word_i in range(sent_start_pos, sent_end_pos):
                    seg_input = input_tensor[[word_i], :]
                    combined_input = torch.cat([seg_input, prev_tensor], dim=1)
                    dropout_input = self.dropout_layer(combined_input).view(1, 1, -1)
                    if last_h is None and last_c is None:
                        lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input)
                    else:
                        lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input, (last_h, last_c))
                    if internal_interaction is not None:
                        interaction_args["sent_id"] = sent_i
                        lstm_out = internal_interaction(lstm_out, **interaction_args)[:, 0]
                    lstm_out = self.linear_layer(lstm_out.view(1, -1))
                    output_tensor_list.append(lstm_out)
                    _, index = lstm_out.topk(1, dim=1)
                    prev_tensor = self.embedding_layer(index).view(1, -1)
                sent_start_pos = sent_end_pos
        seq_unpacked = unpack_sequence(torch.cat(output_tensor_list, dim=0), seq_lens)
        # TODO: support softmax for all modes (including use_multi)
        if self.config.get("use_multi"):
            pred_output = ClassifierOutputData(seq_unpacked)
        else:
            pred_output = ClassifierOutputData(F.log_softmax(seq_unpacked, dim=-1))
        return pred_output
class MLPClassifier(BaseClassifier):
    """
    Decoder structure based on MLP.
    """

    def __init__(self, **config):
        """ Construction function for Decoder.

        Args:
            config (dict):
                use_slot (bool): whether to classify slot label.
                use_intent (bool): whether to classify intent label.
                mlp (List): layer specs passed one by one to ``instantiate``, e.g.
                    - _model_target_: torch.nn.Linear
                      in_features (int | str): input feature dim
                      out_features (int | str): output feature dim
                    - _model_target_: torch.nn.LeakyReLU
                      negative_slope: 0.2
                    - ...
                    A string ``in_features``/``out_features`` is a placeholder: its first and
                    last characters (e.g. braces) are stripped and the rest is looked up as a
                    key of ``config``.
        """
        import copy

        super(MLPClassifier, self).__init__(**config)
        self.config = config
        layer_specs = []
        for raw_spec in config["mlp"]:
            # Resolve placeholders on a copy so the caller's nested config is not mutated.
            spec = copy.deepcopy(raw_spec)
            for key in ("in_features", "out_features"):
                if isinstance(spec.get(key), str):
                    spec[key] = self.config[spec[key][1:-1]]
            layer_specs.append(spec)
        self.seq = nn.Sequential(*[instantiate(spec) for spec in layer_specs])

    def forward(self, hidden: HiddenData):
        """Run the MLP on ``hidden.intent_hidden`` when ``use_intent`` is set, else on ``hidden.slot_hidden``."""
        if self.config.get("use_intent"):
            res = self.seq(hidden.intent_hidden)
        else:
            res = self.seq(hidden.slot_hidden)
        return ClassifierOutputData(res)