Spaces:
Runtime error
Runtime error
| ''' | |
| Author: Qiguang Chen | |
| Date: 2023-01-11 10:39:26 | |
| LastEditors: Qiguang Chen | |
| LastEditTime: 2023-02-18 17:38:30 | |
| Description: pretrained encoder model | |
| ''' | |
| from transformers import AutoModel, AutoConfig | |
| from common import utils | |
| from common.utils import InputData, HiddenData | |
| from model.encoder.base_encoder import BaseEncoder | |
| class PretrainedEncoder(BaseEncoder): | |
| def __init__(self, **config): | |
| """ init pretrained encoder | |
| Args: | |
| config (dict): | |
| encoder_name (str): pretrained model name in hugging face. | |
| """ | |
| super().__init__(**config) | |
| if self.config.get("_is_check_point_"): | |
| self.encoder = utils.instantiate(config["pretrained_model"], target="_pretrained_model_target_") | |
| # print(self.encoder) | |
| else: | |
| self.encoder = AutoModel.from_pretrained(config["encoder_name"]) | |
| def forward(self, inputs: InputData): | |
| output = self.encoder(**inputs.get_inputs()) | |
| hidden = HiddenData(None, output.last_hidden_state) | |
| if self.config.get("return_with_input"): | |
| hidden.add_input(inputs) | |
| if self.config.get("return_sentence_level_hidden"): | |
| padding_side = self.config.get("padding_side") | |
| if hasattr(output, "pooler_output"): | |
| hidden.update_intent_hidden_state(output.pooler_output) | |
| elif padding_side is not None and padding_side == "left": | |
| hidden.update_intent_hidden_state(output.last_hidden_state[:, -1]) | |
| else: | |
| hidden.update_intent_hidden_state(output.last_hidden_state[:, 0]) | |
| return hidden | |