Spaces:
Runtime error
Runtime error
| ''' | |
| Author: Qiguang Chen | |
| Date: 2023-01-11 10:39:26 | |
| LastEditors: Qiguang Chen | |
| LastEditTime: 2023-02-18 19:33:34 | |
| Description: | |
| ''' | |
| from common.utils import InputData | |
| from model.encoder.base_encoder import BaseEncoder, BiEncoder | |
| from model.encoder.pretrained_encoder import PretrainedEncoder | |
| from model.encoder.non_pretrained_encoder import NonPretrainedEncoder | |
class AutoEncoder(BaseEncoder):
    """Encoder wrapper that instantiates a concrete encoder from ``encoder_name``."""

    def __init__(self, **config):
        """Automatically load an encoder selected by ``encoder_name``.

        Args:
            config (dict):
                encoder_name (str): supports ["lstm", "self-attention-lstm",
                    "bi-encoder"]; any other name is handed to
                    ``PretrainedEncoder`` (e.g. a pretrained model from
                    Hugging Face).
                intent_encoder (dict): nested encoder config, read only when
                    ``encoder_name`` is "bi-encoder".
                slot_encoder (dict, optional): nested encoder config for the
                    second branch of the bi-encoder; when absent, the
                    ``intent_encoder`` config is reused.
                **args (Any): other configuration items corresponding to
                    each module.

        Raises:
            ValueError: if ``encoder_name`` is missing or empty.
        """
        super().__init__()
        self.config = config

        encoder_name = config.get("encoder_name")
        if not encoder_name:
            raise ValueError("There is no Encoder Name in config.")
        encoder_name = encoder_name.lower()

        if encoder_name in ("lstm", "self-attention-lstm"):
            self.__encoder = NonPretrainedEncoder(**config)
        elif encoder_name == "bi-encoder":
            # Build two independent sub-encoders. Calling ``self.__init__``
            # here (as the code previously did) returns None and re-initialises
            # this very instance, so BiEncoder would receive (None, None).
            intent_config = config["intent_encoder"]
            # NOTE(review): the previous code read "intent_encoder" for both
            # arguments, which looks like a copy-paste slip. Prefer a dedicated
            # "slot_encoder" section when present and fall back to the intent
            # config otherwise -- confirm against BiEncoder's signature.
            slot_config = config.get("slot_encoder", intent_config)
            self.__encoder = BiEncoder(
                AutoEncoder(**intent_config),
                AutoEncoder(**slot_config),
            )
        else:
            self.__encoder = PretrainedEncoder(**config)

    def forward(self, inputs: InputData):
        """Run the wrapped encoder on ``inputs`` and return its result unchanged."""
        return self.__encoder(inputs)