Spaces:
Runtime error
Runtime error
| ''' | |
| Author: Qiguang Chen | |
| LastEditors: Qiguang Chen | |
| Date: 2023-02-13 10:44:39 | |
| LastEditTime: 2023-02-14 10:28:43 | |
| Description: | |
| ''' | |
| import os | |
| import dill | |
| from common import utils | |
| from common.utils import InputData, download | |
| from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer | |
| # parser = argparse.ArgumentParser() | |
| # parser.add_argument('--config_path', '-cp', type=str, default="config/reproduction/atis/joint_bert.yaml") | |
| # args = parser.parse_args() | |
| # config = Config.load_from_yaml(args.config_path) | |
| # config.base["train"] = False | |
| # config.base["test"] = False | |
| # model_manager = ModelManager(config) | |
| # model_manager.load() | |
class PretrainedConfigForSLU(PretrainedConfig):
    """Configuration class used by the SLU model wrapper.

    It declares no fields of its own: every keyword argument is handed
    straight to ``PretrainedConfig``.
    """

    def __init__(self, **kwargs) -> None:
        """Forward all keyword arguments unchanged to the parent constructor."""
        super().__init__(**kwargs)
| # pretrained_config = PretrainedConfigForSLU() | |
| # # pretrained_config.push_to_hub("xxxx") | |
class PretrainedModelForSLU(PreTrainedModel):
    """``PreTrainedModel`` wrapper around an SLU model built from its config.

    The wrapped model is created with ``utils.instantiate(config.model)`` and
    kept on ``self.model``.
    """

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs) -> None:
        """Initialise the base class and build the wrapped model.

        Args:
            config: Configuration object; its ``model`` attribute is passed to
                ``utils.instantiate``.
            *inputs: Extra positional arguments forwarded to ``PreTrainedModel``.
            **kwargs: Extra keyword arguments forwarded to ``PreTrainedModel``.
        """
        super().__init__(config, *inputs, **kwargs)
        # NOTE(review): this stores the config *instance* under the name
        # ``config_class`` on the instance, shadowing the class-level attribute
        # assigned in ``from_pretrained`` -- confirm this is intended.
        self.config_class = config
        self.model = utils.instantiate(config.model)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model through ``PreTrainedModel.from_pretrained``.

        Must be a classmethod: it is invoked on the class itself
        (``PretrainedModelForSLU.from_pretrained(name)``), assigns to
        ``cls.config_class`` and relies on zero-argument ``super()``. Without
        the decorator the model name would be bound to ``cls`` and the call
        would fail.

        Args:
            pretrained_model_name_or_path: Identifier or path forwarded to the
                parent implementation.
            *model_args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.

        Returns:
            Whatever the parent ``from_pretrained`` returns.
        """
        # Point the parent loader at the SLU config class before delegating.
        cls.config_class = PretrainedConfigForSLU
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
class PreTrainedTokenizerForSLU(PreTrainedTokenizer):
    """Tokenizer wrapper whose ``from_pretrained`` loads a pickled tokenizer.

    The tokenizer file is cached under ``./save/<name>/tokenizer.pkl`` and is
    downloaded from ``huggingface.co`` when the cached file is missing.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``PreTrainedTokenizer``."""
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Return the object unpickled from the cached ``tokenizer.pkl``.

        Must be a classmethod: it is invoked on the class itself
        (``PreTrainedTokenizerForSLU.from_pretrained(name)``); without the
        decorator the name would be bound to ``cls`` and the call would fail.

        Args:
            pretrained_model_name_or_path: Repository identifier; used both as
                the cache sub-directory under ``save/`` and in the download URL.
            *model_args: Accepted for signature compatibility; unused here.
            **kwargs: Accepted for signature compatibility; unused here.

        Returns:
            The object stored in ``tokenizer.pkl`` (loaded with ``dill``).
        """
        cache_dir = f"./save/{pretrained_model_name_or_path}"
        # Create every missing directory level in one call; ``exist_ok`` avoids
        # the check-then-create race of a manual exists()/mkdir() loop.
        os.makedirs(cache_dir, exist_ok=True)

        cache_path = f"{cache_dir}/tokenizer.pkl"
        if not os.path.exists(cache_path):
            download(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/tokenizer.pkl", cache_path)

        # SECURITY: dill/pickle deserialisation can execute arbitrary code.
        # Only load tokenizers from repositories you trust.
        with open(cache_path, "rb") as f:
            tokenizer = dill.load(f)
        return tokenizer
| # pretrained_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") | |
| # pretrained_tokenizer = PreTrainedTokenizerForSLU.from_pretrained("LightChen2333/joint-bert-slu-atis") | |
| # test_model = PretrainedModelForSLU.from_pretrained("LightChen2333/joint-bert-slu-atis") | |
| # print(test_model(InputData([pretrained_tokenizer("I want to go to Beijing !")]))) |