File size: 3,179 Bytes
ba0c78a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Convert a pretrained TrOCR model to the character set of a custom dataset, producing initial weights for fine-tuning.
"""
import os
import json

# Hide all CUDA devices so the conversion below runs on CPU. This is set before
# the transformers import so it is in place before any CUDA initialization.
os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
import argparse
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from transformers import AutoConfig


def read_vocab(vocab_path, encoding="utf-8"):
    """
    Read the custom training character set.

    The file holds one token per line, e.g.::

        1
        2
        ...
        我

    The special tokens ``<s>``, ``<pad>``, ``</s>``, ``<unk>`` and ``<mask>``
    always receive ids 0-4. Every distinct token from the file is then
    appended in order of first appearance. Duplicates (including lines equal
    to a special token) and empty lines are ignored.

    Args:
        vocab_path: path to the vocabulary text file.
        encoding: text encoding of the file. Defaults to UTF-8 so non-ASCII
            characters decode identically regardless of the platform locale.

    Returns:
        dict mapping token -> contiguous integer id; insertion order equals
        id order.
    """
    special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
    vocab = {token: idx for idx, token in enumerate(special_tokens)}

    with open(vocab_path, encoding=encoding) as f:
        for line in f:
            # Strip only the line terminator: spaces may be real characters.
            token = line.rstrip('\n')
            # Skip blank lines, which would otherwise add "" as a token.
            if token and token not in vocab:
                vocab[token] = len(vocab)
    return vocab


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='trocr fine-tune训练')

    parser.add_argument('--cust_vocab', default="./cust-data/vocab.txt", type=str, help="自定义训练数字符集")

    parser.add_argument('--pretrain_model', default='./weights', type=str, help="预训练bert权重文件")

    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="初始化训练权重,用于自己数据集上fine-tune权重")
    args = parser.parse_args()

    processor = TrOCRProcessor.from_pretrained(args.pretrain_model)
    pre_model = VisionEncoderDecoderModel.from_pretrained(args.pretrain_model)

    pre_vocab = processor.tokenizer.get_vocab()

    cust_vocab = read_vocab(args.cust_vocab)

    keep_tokens = []
    unk_index = pre_vocab.get('<unk>')
    for key in cust_vocab:
        keep_tokens.append(pre_vocab.get(key, unk_index))

    processor.save_pretrained(args.cust_data_init_weights_path)

    pre_model.save_pretrained(args.cust_data_init_weights_path)
    ## 替换词库
    with open(os.path.join(args.cust_data_init_weights_path, "vocab.json"), "w") as f:
        f.write(json.dumps(cust_vocab, ensure_ascii=False))

    ##替换模型参数
    with open(os.path.join(args.cust_data_init_weights_path, "config.json")) as f:
        model_config = json.load(f)

    ## 替换roberta embedding层词库
    model_config["decoder"]['vocab_size'] = len(cust_vocab)

    ## 替换 attetion 字库
    model_config['vocab_size'] = len(cust_vocab)

    with open(os.path.join(args.cust_data_init_weights_path, "config.json"), "w") as f:
        f.write(json.dumps(model_config, ensure_ascii=False))

    ##加载cust model
    cust_config = AutoConfig.from_pretrained(args.cust_data_init_weights_path)
    cust_model = VisionEncoderDecoderModel(cust_config)

    pre_model_weigths = pre_model.state_dict()
    cust_model_weigths = cust_model.state_dict()

    ##权重初始化
    print("loading init weights..................")
    for key in pre_model_weigths:
        print("name:", key)
        if pre_model_weigths[key].shape != cust_model_weigths[key].shape:
            wt = pre_model_weigths[key][keep_tokens, :]
            cust_model_weigths[key] = wt
        else:
            cust_model_weigths[key] = pre_model_weigths[key]

    cust_model.load_state_dict(cust_model_weigths)
    cust_model.save_pretrained(args.cust_data_init_weights_path)