|
|
""" |
|
|
转换trocr 模型到自己数据集上的字符进行fine-tune |
|
|
""" |
|
|
import os |
|
|
import json |
|
|
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = '-1' |
|
|
import argparse |
|
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
|
from transformers import AutoConfig |
|
|
|
|
|
|
|
|
def read_vocab(vocab_path): |
|
|
""" |
|
|
读取自定义训练字符集 |
|
|
vocab_path format: |
|
|
1\n |
|
|
2\n |
|
|
... |
|
|
我\n |
|
|
""" |
|
|
other = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"] |
|
|
vocab = {} |
|
|
for ot in other: |
|
|
vocab[ot] = len(vocab) |
|
|
|
|
|
with open(vocab_path) as f: |
|
|
for line in f: |
|
|
line = line.strip('\n') |
|
|
if line not in vocab: |
|
|
vocab[line] = len(vocab) |
|
|
return vocab |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
parser = argparse.ArgumentParser(description='trocr fine-tune训练') |
|
|
|
|
|
parser.add_argument('--cust_vocab', default="./cust-data/vocab.txt", type=str, help="自定义训练数字符集") |
|
|
|
|
|
parser.add_argument('--pretrain_model', default='./weights', type=str, help="预训练bert权重文件") |
|
|
|
|
|
parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str, |
|
|
help="初始化训练权重,用于自己数据集上fine-tune权重") |
|
|
args = parser.parse_args() |
|
|
|
|
|
processor = TrOCRProcessor.from_pretrained(args.pretrain_model) |
|
|
pre_model = VisionEncoderDecoderModel.from_pretrained(args.pretrain_model) |
|
|
|
|
|
pre_vocab = processor.tokenizer.get_vocab() |
|
|
|
|
|
cust_vocab = read_vocab(args.cust_vocab) |
|
|
|
|
|
keep_tokens = [] |
|
|
unk_index = pre_vocab.get('<unk>') |
|
|
for key in cust_vocab: |
|
|
keep_tokens.append(pre_vocab.get(key, unk_index)) |
|
|
|
|
|
processor.save_pretrained(args.cust_data_init_weights_path) |
|
|
|
|
|
pre_model.save_pretrained(args.cust_data_init_weights_path) |
|
|
|
|
|
with open(os.path.join(args.cust_data_init_weights_path, "vocab.json"), "w") as f: |
|
|
f.write(json.dumps(cust_vocab, ensure_ascii=False)) |
|
|
|
|
|
|
|
|
with open(os.path.join(args.cust_data_init_weights_path, "config.json")) as f: |
|
|
model_config = json.load(f) |
|
|
|
|
|
|
|
|
model_config["decoder"]['vocab_size'] = len(cust_vocab) |
|
|
|
|
|
|
|
|
model_config['vocab_size'] = len(cust_vocab) |
|
|
|
|
|
with open(os.path.join(args.cust_data_init_weights_path, "config.json"), "w") as f: |
|
|
f.write(json.dumps(model_config, ensure_ascii=False)) |
|
|
|
|
|
|
|
|
cust_config = AutoConfig.from_pretrained(args.cust_data_init_weights_path) |
|
|
cust_model = VisionEncoderDecoderModel(cust_config) |
|
|
|
|
|
pre_model_weigths = pre_model.state_dict() |
|
|
cust_model_weigths = cust_model.state_dict() |
|
|
|
|
|
|
|
|
print("loading init weights..................") |
|
|
for key in pre_model_weigths: |
|
|
print("name:", key) |
|
|
if pre_model_weigths[key].shape != cust_model_weigths[key].shape: |
|
|
wt = pre_model_weigths[key][keep_tokens, :] |
|
|
cust_model_weigths[key] = wt |
|
|
else: |
|
|
cust_model_weigths[key] = pre_model_weigths[key] |
|
|
|
|
|
cust_model.load_state_dict(cust_model_weigths) |
|
|
cust_model.save_pretrained(args.cust_data_init_weights_path) |
|
|
|