"""Convert a pretrained TrOCR model to a custom dataset's character set for fine-tuning.

The decoder vocabulary of the pretrained model is replaced by the tokens listed
in a custom vocab file. Vocab-sized weight rows are gathered from the pretrained
weights; tokens unknown to the pretrained tokenizer reuse its ``<unk>`` row.
"""
import os
import json

# Hide all GPUs so the conversion runs on CPU. This is set before torch /
# transformers are imported (they are imported lazily inside main()).
os.environ["CUDA_VISIBLE_DEVICES"] = '-1'

import argparse

# Special tokens placed at the start of the custom vocab, in this order (ids 0..4).
# They must be distinct, real token strings: identical entries (e.g. five empty
# strings) collapse into a single vocab id.
# NOTE(review): presumably mirrors the pretrained (RoBERTa-style) tokenizer's
# special-token ids; confirm decoder_start/pad/eos ids in config.json still line up.
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
UNK_TOKEN = "<unk>"


def read_vocab(vocab_path):
    """Read the custom training character set.

    ``vocab_path`` is a UTF-8 text file with one token per line, e.g.::

        1
        2
        ...
        我

    Returns a dict mapping token -> id. The tokens in ``SPECIAL_TOKENS`` get ids
    0..len(SPECIAL_TOKENS)-1; file tokens follow in first-seen order. Duplicate
    tokens and blank lines are skipped. Only the line terminator is removed, so
    whitespace tokens such as a single space are kept.
    """
    vocab = {token: idx for idx, token in enumerate(SPECIAL_TOKENS)}
    with open(vocab_path, encoding="utf-8") as f:
        for line in f:
            # Strip "\r" as well so CRLF files don't yield tokens ending in "\r".
            token = line.rstrip("\r\n")
            if token and token not in vocab:
                vocab[token] = len(vocab)
    return vocab


def transfer_weights(pre_state, cust_state, keep_tokens):
    """Copy pretrained tensors into ``cust_state`` (mutated in place and returned).

    Tensors whose shape matches the target are copied as-is. A tensor whose shape
    differs is treated as vocabulary-indexed along dim 0, and the rows listed in
    ``keep_tokens`` are gathered (this works for 2-D weights and 1-D biases).

    Raises:
        KeyError: a pretrained key is missing from ``cust_state``.
        ValueError: gathering ``keep_tokens`` rows still does not match the
            target shape (i.e. the mismatch is not along the vocab dimension).
    """
    for name, tensor in pre_state.items():
        print("name:", name)
        target_shape = cust_state[name].shape
        if tensor.shape != target_shape:
            tensor = tensor[keep_tokens]
            if tensor.shape != target_shape:
                raise ValueError(
                    f"cannot map weight {name!r}: got {tuple(tensor.shape)} after "
                    f"vocab gathering, expected {tuple(target_shape)}"
                )
        cust_state[name] = tensor
    return cust_state


def main():
    """Parse CLI args, build the custom-vocab model and save it to disk."""
    # Deferred so read_vocab / transfer_weights can be imported without transformers.
    from transformers import AutoConfig, TrOCRProcessor, VisionEncoderDecoderModel

    parser = argparse.ArgumentParser(description='trocr fine-tune训练')
    parser.add_argument('--cust_vocab', default="./cust-data/vocab.txt", type=str, help="自定义训练数字符集")
    parser.add_argument('--pretrain_model', default='./weights', type=str, help="预训练bert权重文件")
    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="初始化训练权重,用于自己数据集上fine-tune权重")
    args = parser.parse_args()
    out_dir = args.cust_data_init_weights_path

    processor = TrOCRProcessor.from_pretrained(args.pretrain_model)
    pre_model = VisionEncoderDecoderModel.from_pretrained(args.pretrain_model)

    pre_vocab = processor.tokenizer.get_vocab()
    cust_vocab = read_vocab(args.cust_vocab)

    # Fail before writing anything if the fallback token is missing.
    unk_index = pre_vocab.get(UNK_TOKEN)
    if unk_index is None:
        raise ValueError(f"pretrained tokenizer has no {UNK_TOKEN!r} token")
    # For each custom token id, the pretrained row to reuse; tokens unknown to
    # the pretrained tokenizer reuse the <unk> row.
    keep_tokens = [pre_vocab.get(token, unk_index) for token in cust_vocab]

    processor.save_pretrained(out_dir)
    pre_model.save_pretrained(out_dir)

    # Replace the saved tokenizer vocabulary with the custom one.
    with open(os.path.join(out_dir, "vocab.json"), "w", encoding="utf-8") as f:
        f.write(json.dumps(cust_vocab, ensure_ascii=False))

    # Patch the vocab sizes in the saved model config.
    config_path = os.path.join(out_dir, "config.json")
    with open(config_path, encoding="utf-8") as f:
        model_config = json.load(f)
    # Decoder (embedding layer) vocabulary size.
    model_config["decoder"]['vocab_size'] = len(cust_vocab)
    # Top-level vocabulary size.
    model_config['vocab_size'] = len(cust_vocab)
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(model_config, ensure_ascii=False))

    # Build a fresh model with the custom vocab size, then initialize its weights
    # from the pretrained model.
    cust_config = AutoConfig.from_pretrained(out_dir)
    cust_model = VisionEncoderDecoderModel(cust_config)

    print("loading init weights..................")
    cust_state = transfer_weights(pre_model.state_dict(), cust_model.state_dict(), keep_tokens)
    cust_model.load_state_dict(cust_state)
    cust_model.save_pretrained(out_dir)


if __name__ == '__main__':
    main()